This repository was archived by the owner on Apr 29, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Sortagrad #13
Open
giuseppeCoccia
wants to merge
20
commits into
master
Choose a base branch
from
sortagrad
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Sortagrad #13
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
6b160aa
Added sortagrad strategy
giuseppeCoccia 3751dde
Added tests for sortagrad
giuseppeCoccia e054ed9
Delete seq_to_seq_collate_fn_sorted and add sort flag in seq_to_seq_c…
giuseppeCoccia 8bdddb2
Fix doc strings + add sortagrad reference
giuseppeCoccia a0b3afe
Add docs about arguments in samplers functions
giuseppeCoccia 6092921
Avoid dataloader repetition
giuseppeCoccia 852c574
Change API to define shuffle strategy
giuseppeCoccia 8b9db16
Merge branch 'master' into sortagrad
giuseppeCoccia 971552f
Add proto file for shuffle strategy
giuseppeCoccia 2997628
rename sortagrad variable using snake_case + remove unused import
giuseppeCoccia 25b787e
Fix test error caused by changed sorta_grad variable name
giuseppeCoccia 7965f44
Move update of num_iterators before for loop in sampler + change sort…
giuseppeCoccia 26282f1
Remove sort within single batches
giuseppeCoccia b46ed75
Add types + change doc syntaxt for hyperlink
giuseppeCoccia ac09f5a
Add sampler documentation to docs
giuseppeCoccia b4e0367
Update sampler tests to use Hypothesis
giuseppeCoccia 4b683fa
Fix doc strings and parameter types
giuseppeCoccia ea6c2f6
Let the sampler doc be auto-generated
giuseppeCoccia e31db68
Make Hypothesis generate set of sequential epoch numbers
giuseppeCoccia e21db25
Add separate function to create list of sequential epoch numbers
giuseppeCoccia File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,3 +9,4 @@ | |
| alphabet | ||
| dataset/index | ||
| preprocess | ||
| sampler | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| ========= | ||
| sampler | ||
| ========= | ||
|
|
||
| .. automodule:: myrtlespeech.data.sampler | ||
| :members: | ||
| :show-inheritance: | ||
|
|
||
| .. autoclass:: myrtlespeech.data.sampler.SequentialRandomSampler | ||
| :members: | ||
| :show-inheritance: | ||
|
|
||
|
|
||
| .. autoclass:: myrtlespeech.data.sampler.SortaGrad | ||
| :members: | ||
| :show-inheritance: | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -60,7 +60,8 @@ train_config { | |
| } | ||
| } | ||
| } | ||
| shuffle_batches_before_every_epoch: true; | ||
| random_batches { | ||
| } | ||
| } | ||
|
|
||
| eval_config { | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -111,7 +111,8 @@ train_config { | |
| } | ||
| } | ||
| } | ||
| shuffle_batches_before_every_epoch: true; | ||
| sorta_grad { | ||
| } | ||
| } | ||
|
|
||
| eval_config { | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,48 @@ | ||
| import random | ||
| from typing import Iterable | ||
| from typing import Optional | ||
|
|
||
|
samgd marked this conversation as resolved.
|
||
|
|
||
| class RandomBatchSampler: | ||
| """TODO""" | ||
| class SequentialRandomSampler: | ||
| """A sequential or random iterable over batches of indices. | ||
|
|
||
| def __init__(self, indices, batch_size, shuffle, drop_last=False): | ||
| The iterator used each time this iterable is iterated over will yield | ||
|
giuseppeCoccia marked this conversation as resolved.
|
||
| batches of indices either sequentially (i.e. in-order) or randomly (uniform | ||
| without replacement). | ||
| This iterable records the number of times it has returned an iterator. A | ||
| sequential iterator is returned if the current count is in `sequential`. | ||
|
|
||
| Args: | ||
| indices: Data with which batches are created. | ||
| batch_size: Batch dimension. | ||
| shuffle: Set to True to have the data reshuffled at every epoch if a | ||
| random iterator is used. | ||
| drop_last: Set to True to drop the last incomplete batch, if the | ||
| dataset size is not divisible by the batch size. If False and the | ||
| size of dataset is not divisible by the batch size, then the last | ||
| batch will be smaller. | ||
| n_iterators: Number of iterators returned so far. | ||
| sequential: Counts at which to return a sequential iterator. | ||
|
|
||
| Yields: | ||
| Batches from `indices`. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| indices: Iterable[int], | ||
| batch_size: int, | ||
| shuffle: bool, | ||
| drop_last: bool = False, | ||
| n_iterators: int = 0, | ||
| sequential: Optional[set] = None, | ||
| ): | ||
| self.shuffle = shuffle | ||
| self.batch_indices = self._batch_indices( | ||
| indices, batch_size, drop_last | ||
| ) | ||
| self._n_iterators = n_iterators | ||
| self._sequential = sequential or set() | ||
|
|
||
| def _batch_indices(self, indices, batch_size, drop_last): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add types.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Types? |
||
| batches = [] | ||
|
|
@@ -23,10 +57,53 @@ def _batch_indices(self, indices, batch_size, drop_last): | |
| return batches | ||
|
|
||
| def __iter__(self): | ||
| if self.shuffle: | ||
| random.shuffle(self.batch_indices) | ||
| for b in self.batch_indices: | ||
| yield b | ||
| indices = list(range(len(self.batch_indices))) | ||
| if self.shuffle and self._n_iterators not in self._sequential: | ||
| random.shuffle(indices) | ||
| self._n_iterators += 1 | ||
| for index in indices: | ||
| yield self.batch_indices[index] | ||
|
|
||
| def __len__(self): | ||
| return len(self.batch_indices) | ||
|
|
||
|
|
||
| class SortaGrad(SequentialRandomSampler): | ||
| """An iterable over batch indices according to the SortaGrad strategy. | ||
|
|
||
| The SortaGrad curriculum learning strategy iterates over batches from the | ||
|
samgd marked this conversation as resolved.
|
||
| batched dataset sequentially for the first pass and then randomly for all | ||
| other passes. See `Deep Speech 2 <https://arxiv.org/abs/1512.02595>`_ paper | ||
| for more information. | ||
|
|
||
| Args: | ||
| indices: Data with which batches are created. | ||
| batch_size: Batch dimension. | ||
| shuffle: Set to True to have the data reshuffled at every epoch if a | ||
| random iterator is used. | ||
| drop_last: Set to True to drop the last incomplete batch, if the | ||
| dataset size is not divisible by the batch size. If False and the | ||
| size of dataset is not divisible by the batch size, then the last | ||
| batch will be smaller. | ||
| start_epoch: Number of iterators returned so far by the sampler. | ||
|
|
||
| Yields: | ||
| Batches from `indices`. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| indices: Iterable[int], | ||
| batch_size: int, | ||
| shuffle: bool, | ||
| drop_last: bool = False, | ||
| start_epoch: int = 0, | ||
| ): | ||
| super().__init__( | ||
| indices, | ||
| batch_size, | ||
| shuffle, | ||
| drop_last, | ||
| n_iterators=start_epoch, | ||
| sequential={0}, | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| syntax = "proto3"; | ||
|
|
||
| package myrtlespeech.protos; | ||
|
|
||
|
|
||
| message SequentialBatches { | ||
| } | ||
|
|
||
|
|
||
| message RandomBatches { | ||
| } | ||
|
|
||
|
|
||
| message SortaGrad { | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be auto-generated to reduce chance of forgetting to update it in the future?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have added this at the beginning of the file