[WIP] RNN-T + MBR training. #593
The model structure is shown in the diagram below. It has two joiners: one is the joiner for regular RNN-T, the other is
self.encoder_output_layer = ScaledLinear(
    d_model, num_classes, bias=True
)
The transformer LM is actually an embedding layer plus a TransformerEncoder that encodes the symbols into text_embedding.
    dropout=dropout,
    layer_dropout=layer_dropout,
)
self.enhancer = TransformerDecoder(decoder_layer, num_layers)
The EmbeddingEnhancer is a TransformerDecoder that applies self-attention over masked_encoder_output and cross-attention over text_embedding.
N, T, C = embedding.shape
mask = torch.randn((N, T, C), device=embedding.device)
mask = mask > mask_proportion
masked_embedding = torch.masked_fill(embedding, ~mask, 0.0)
I randomly mask the encoder output here.
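A note on mask density: the snippet above draws the mask with `torch.randn`, which samples a standard normal, and zeroes every entry whose draw is at or below `mask_proportion`. If `mask_proportion` is meant to be the fraction of entries zeroed (an assumption; the PR does not say), thresholding a normal draw gives a different fraction than thresholding a uniform draw such as `torch.rand`. The sketch below is plain Python, not the PR's code, and both function names are hypothetical. It only computes the expected zeroed fraction under each kind of draw.

```python
from statistics import NormalDist


def zeroed_fraction_uniform(mask_proportion: float) -> float:
    """Expected fraction of entries zeroed when the draw is U[0, 1).

    An entry is zeroed when draw <= mask_proportion, so the fraction is
    mask_proportion itself, clipped to [0, 1].
    """
    return min(max(mask_proportion, 0.0), 1.0)


def zeroed_fraction_normal(mask_proportion: float) -> float:
    """Expected fraction of entries zeroed when the draw is N(0, 1).

    An entry is zeroed when draw <= mask_proportion, which is the
    standard normal CDF evaluated at mask_proportion.
    """
    return NormalDist(mu=0.0, sigma=1.0).cdf(mask_proportion)


if __name__ == "__main__":
    for p in (0.0, 0.1, 0.25, 0.5):
        print(p, zeroed_fraction_uniform(p), round(zeroed_fraction_normal(p), 4))
```

With a normal draw, a threshold of 0.0 already zeroes half of the entries, so the two choices diverge most for small values of `mask_proportion`.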
    )
    return init_context
def delta_wer(
This function implements the sampling process.
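For readers unfamiliar with the quantity the function name suggests, here is a minimal sketch of one common way to turn sampled hypotheses into a "delta WER" signal: compute the word errors (edit distance) of each sample against the reference and subtract the mean over the samples. This is an assumption about the general technique used in minimum-WER style training, not a description of what `delta_wer` in this PR does. The names `edit_distance` and `delta_word_errors` are hypothetical, and the sampling itself is not shown.

```python
from typing import List, Sequence


def edit_distance(ref: Sequence[int], hyp: Sequence[int]) -> int:
    """Levenshtein distance between two token sequences."""
    prev = list(range(len(hyp) + 1))  # distances from the empty reference prefix
    for i, r in enumerate(ref, start=1):
        cur = [i]  # distance from ref[:i] to the empty hypothesis
        for j, h in enumerate(hyp, start=1):
            cur.append(
                min(
                    prev[j] + 1,  # deletion
                    cur[j - 1] + 1,  # insertion
                    prev[j - 1] + (r != h),  # substitution or match
                )
            )
        prev = cur
    return prev[-1]


def delta_word_errors(
    ref: Sequence[int], samples: Sequence[Sequence[int]]
) -> List[float]:
    """Word errors of each sample minus the mean over all samples."""
    if not samples:
        raise ValueError("need at least one sampled hypothesis")
    errors = [edit_distance(ref, s) for s in samples]
    mean = sum(errors) / len(errors)
    return [e - mean for e in errors]


if __name__ == "__main__":
    ref = [1, 2, 3]
    samples = [[1, 2, 3], [1, 3], [1, 2, 4, 3]]
    print(delta_word_errors(ref, samples))
```

Subtracting the mean makes the values sum to zero across samples, so samples that are better than average get a negative weight and worse ones a positive weight.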
    + l2_loss_scale * l2_loss
    + delta_wer_scale * delta_wer_loss
    + predictor_loss_scale * predictor_loss
)
The losses are combined here.
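The combination shown in the diff is a weighted sum of per-term losses. The terms before `l2_loss` are not visible in the snippet, so the sketch below does not guess at them. It is a generic illustration with a hypothetical helper `combine_losses` and plain floats standing in for tensors.

```python
from typing import Dict


def combine_losses(losses: Dict[str, float], scales: Dict[str, float]) -> float:
    """Return sum(scales[name] * losses[name]) over all loss terms.

    Raises KeyError if a loss term has no matching scale, so a missing
    weight is caught instead of silently defaulting.
    """
    return sum(scales[name] * value for name, value in losses.items())


if __name__ == "__main__":
    # Illustrative values only; they are not taken from the PR.
    losses = {"l2_loss": 2.0, "delta_wer_loss": 4.0, "predictor_loss": 1.0}
    scales = {"l2_loss": 0.5, "delta_wer_loss": 0.25, "predictor_loss": 1.0}
    print(combine_losses(losses, scales))
```

Keeping the scales in one mapping makes it easy to switch a term off by setting its scale to 0.0 during experiments.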
@danpovey @yaozengwei @glynpu Would you please have a look at this? If anything is unclear, please let me know. Thanks!
Sure. I will have a look.

This PR depends on k2-fsa/k2#1057 in k2.