Skip to content
This repository was archived by the owner on Apr 29, 2021. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
6b160aa
Added sortagrad strategy
giuseppeCoccia Nov 22, 2019
3751dde
Added tests for sortagrad
giuseppeCoccia Nov 22, 2019
e054ed9
Delete seq_to_seq_collate_fn_sorted and add sort flag in seq_to_seq_c…
giuseppeCoccia Nov 27, 2019
8bdddb2
Fix doc strings + add sortagrad reference
giuseppeCoccia Nov 27, 2019
a0b3afe
Add docs about arguments in samplers functions
giuseppeCoccia Nov 27, 2019
6092921
Avoid dataloader repetition
giuseppeCoccia Nov 27, 2019
852c574
Change API to define shuffle strategy
giuseppeCoccia Nov 28, 2019
8b9db16
Merge branch 'master' into sortagrad
giuseppeCoccia Nov 28, 2019
971552f
Add proto file for shuffle strategy
giuseppeCoccia Nov 28, 2019
2997628
rename sortagrad variable using snake_case + remove unused import
giuseppeCoccia Dec 2, 2019
25b787e
Fix test error caused by changed sorta_grad variable name
giuseppeCoccia Dec 2, 2019
7965f44
Move update of num_iterators before for loop in sampler + change sort…
giuseppeCoccia Dec 3, 2019
26282f1
Remove sort within single batches
giuseppeCoccia Dec 3, 2019
b46ed75
Add types + change doc syntaxt for hyperlink
giuseppeCoccia Dec 6, 2019
ac09f5a
Add sampler documentation to docs
giuseppeCoccia Dec 6, 2019
b4e0367
Update sampler tests to use Hypothesis
giuseppeCoccia Dec 6, 2019
4b683fa
Fix doc strings and parameter types
giuseppeCoccia Dec 12, 2019
ea6c2f6
Let the sampler doc be auto-generated
giuseppeCoccia Dec 12, 2019
e31db68
Make Hypothesis generate set of sequential epoch numbers
giuseppeCoccia Dec 12, 2019
e21db25
Add separate function to create list of sequential epoch numbers
giuseppeCoccia Dec 12, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/myrtlespeech/data/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
alphabet
dataset/index
preprocess
sampler
16 changes: 16 additions & 0 deletions docs/source/myrtlespeech/data/sampler.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
=========
sampler
=========

.. automodule:: myrtlespeech.data.sampler
:members:
:show-inheritance:

.. autoclass:: myrtlespeech.data.sampler.SequentialRandomSampler
:members:
:show-inheritance:


.. autoclass:: myrtlespeech.data.sampler.SortaGrad
:members:
:show-inheritance:
Comment on lines +9 to +16

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be auto-generated to reduce chance of forgetting to update it in the future?

.. automodule:: myrtlespeech.data.sampler                                        
     :members:                                                                    
     :show-inheritance: 

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added this at the beginning of the file

36 changes: 29 additions & 7 deletions src/myrtlespeech/builders/task_config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import multiprocessing
from typing import Tuple
from typing import Union

import torch
from myrtlespeech.builders.dataset import build as build_dataset
from myrtlespeech.builders.speech_to_text import build as build_stt
from myrtlespeech.data.batch import seq_to_seq_collate_fn
from myrtlespeech.data.sampler import RandomBatchSampler
from myrtlespeech.data.sampler import SequentialRandomSampler
from myrtlespeech.data.sampler import SortaGrad
from myrtlespeech.model.seq_to_seq import SeqToSeq
from myrtlespeech.protos import task_config_pb2

Expand Down Expand Up @@ -114,15 +116,35 @@ def target_transform(target):
add_seq_len_to_transforms=True,
)

shuffle = task_config.train_config.shuffle_batches_before_every_epoch
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_sampler=RandomBatchSampler(
shuffle_str = task_config.train_config.WhichOneof("shuffle_strategy")
batch_sampler: Union[SortaGrad, SequentialRandomSampler]
if shuffle_str == "sorta_grad":
batch_sampler = SortaGrad(
indices=range(len(train_dataset)),
batch_size=task_config.train_config.batch_size,
shuffle=True,
drop_last=False,
)
elif shuffle_str == "random_batches":
batch_sampler = SequentialRandomSampler(
indices=range(len(train_dataset)),
batch_size=task_config.train_config.batch_size,
shuffle=shuffle,
shuffle=True,
drop_last=False,
),
)
elif shuffle_str == "sequential_batches":
batch_sampler = SequentialRandomSampler(
indices=range(len(train_dataset)),
batch_size=task_config.train_config.batch_size,
shuffle=False,
drop_last=False,
)
else:
raise ValueError(f"unsupported shuffle strategy {shuffle_str}")

train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
collate_fn=seq_to_seq_collate_fn,
pin_memory=torch.cuda.is_available(),
Expand Down
3 changes: 2 additions & 1 deletion src/myrtlespeech/configs/deep_speech_1_en.config
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ train_config {
}
}
}
shuffle_batches_before_every_epoch: true;
random_batches {
}
}

eval_config {
Expand Down
3 changes: 2 additions & 1 deletion src/myrtlespeech/configs/deep_speech_2_en.config
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ train_config {
}
}
}
shuffle_batches_before_every_epoch: true;
sorta_grad {
}
}

eval_config {
Expand Down
91 changes: 84 additions & 7 deletions src/myrtlespeech/data/sampler.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,48 @@
import random
from typing import Iterable
from typing import Optional

Comment thread
samgd marked this conversation as resolved.

class RandomBatchSampler:
"""TODO"""
class SequentialRandomSampler:
"""A sequential or random iterable over batches of indices.

def __init__(self, indices, batch_size, shuffle, drop_last=False):
The iterator used each time this iterable is iterated over will yield
Comment thread
giuseppeCoccia marked this conversation as resolved.
batches of indices either sequentially (i.e. in-order) or randomly (uniform
without replacement).
This iterable records the number of times it has returned an iterator. A
sequential iterator is returned if the current count is in `sequential`.

Args:
indices: Data with which batches are created.
batch_size: Batch dimension.
shuffle: Set to True to have the data reshuffled at every epoch if a
random iterator is used.
drop_last: Set to True to drop the last incomplete batch, if the
dataset size is not divisible by the batch size. If False and the
size of dataset is not divisible by the batch size, then the last
batch will be smaller.
n_iterators: Number of iterators returned so far.
sequential: Counts at which to return a sequential iterator.

Yields:
Batches from `indices`.
"""

def __init__(
self,
indices: Iterable[int],
batch_size: int,
shuffle: bool,
drop_last: bool = False,
n_iterators: int = 0,
sequential: Optional[set] = None,
):
self.shuffle = shuffle
self.batch_indices = self._batch_indices(
indices, batch_size, drop_last
)
self._n_iterators = n_iterators
self._sequential = sequential or set()

def _batch_indices(self, indices, batch_size, drop_last):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add types.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Types?

batches = []
Expand All @@ -23,10 +57,53 @@ def _batch_indices(self, indices, batch_size, drop_last):
return batches

def __iter__(self):
if self.shuffle:
random.shuffle(self.batch_indices)
for b in self.batch_indices:
yield b
indices = list(range(len(self.batch_indices)))
if self.shuffle and self._n_iterators not in self._sequential:
random.shuffle(indices)
self._n_iterators += 1
for index in indices:
yield self.batch_indices[index]

def __len__(self):
return len(self.batch_indices)


class SortaGrad(SequentialRandomSampler):
    """Batch sampler implementing the SortaGrad curriculum strategy.

    The first pass over the batched dataset yields batches sequentially (in
    order); every later pass yields them in a random order, provided
    ``shuffle`` is True. See the
    `Deep Speech 2 <https://arxiv.org/abs/1512.02595>`_ paper for more
    information.

    Args:
        indices: Data with which batches are created.
        batch_size: Batch dimension.
        shuffle: Set to True to have the batches reshuffled on every pass
            that uses a random iterator.
        drop_last: Set to True to drop the last incomplete batch, if the
            dataset size is not divisible by the batch size. If False and
            the size of dataset is not divisible by the batch size, then the
            last batch will be smaller.
        start_epoch: Number of iterators the sampler is treated as having
            already returned. Only count ``0`` is sequential, so a non-zero
            value skips the in-order pass.

    Yields:
        Batches from `indices`.
    """

    def __init__(
        self,
        indices: Iterable[int],
        batch_size: int,
        shuffle: bool,
        drop_last: bool = False,
        start_epoch: int = 0,
    ):
        # Iterator counts at which the parent class yields batches in order:
        # only the very first pass is sequential under SortaGrad.
        in_order_counts = {0}
        super().__init__(
            indices=indices,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            n_iterators=start_epoch,
            sequential=in_order_counts,
        )
15 changes: 15 additions & 0 deletions src/myrtlespeech/protos/shuffle_strategy.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
syntax = "proto3";

package myrtlespeech.protos;


message SequentialBatches {
}


message RandomBatches {
}


message SortaGrad {
}
9 changes: 7 additions & 2 deletions src/myrtlespeech/protos/train_config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package myrtlespeech.protos;
import "myrtlespeech/protos/dataset.proto";
import "myrtlespeech/protos/lr_scheduler.proto";
import "myrtlespeech/protos/optimizer.proto";
import "myrtlespeech/protos/shuffle_strategy.proto";


// Configuration for training.
Expand All @@ -29,8 +30,12 @@ message TrainConfig {

Dataset dataset = 9;

oneof supported_shuffles {
oneof shuffle_strategy {
// Maintain a sequential batch order.
SequentialBatches sequential_batches = 10;
// Shuffle batches before every epoch.
bool shuffle_batches_before_every_epoch = 10;
RandomBatches random_batches = 11;
// Sequential for the first epoch and random for the following ones.
SortaGrad sorta_grad = 12;
}
}
2 changes: 1 addition & 1 deletion tests/data/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_pad_sequences_returns_tensor_with_correct_values(


def test_seq_to_seq_collate_fn() -> None:
"""Unit test to ensure seq_to_seq_collate_fn returns correct values."""
"""Ensure seq_to_seq_collate_fn returns correct values"""
inputs = [rand([1, 2, 3]), rand([1, 2, 5])]
seq_lens = tensor([3, 5])

Expand Down
Loading