[WIP] RNN-T + MBR training.#593

Closed. pkufool wants to merge 7 commits into k2-fsa:master from pkufool:mbr
pkufool (Collaborator) commented Sep 29, 2022

This PR depends on k2-fsa/k2#1057 in k2.

@pkufool pkufool requested a review from yaozengwei December 8, 2022 05:41
pkufool (Collaborator, Author) commented Dec 8, 2022

The model structure is shown in the diagram below. It has two joiners: one is the joiner for the regular RNN-T, and the other is a quasi-joiner that produces the expected WER. To make the quasi-joiner work well, we feed it an enhanced embedding instead of the encoder output. The embedding enhancer is a model with self-attention over masked_encoder_output and cross-attention to text_embedding, which is produced by a transformer LM.

[Image: model structure diagram]


self.encoder_output_layer = ScaledLinear(
    d_model, num_classes, bias=True
)

The transformer LM is actually an embedding layer plus a TransformerEncoder that encodes the symbols into text_embedding.
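As a rough illustration of the "embedding layer plus TransformerEncoder" idea, here is a minimal sketch built on the stock `torch.nn` modules. It is not the PR's code: the class name `TextEmbeddingSketch`, the default sizes, and the use of `torch.nn.TransformerEncoder` (rather than the repository's own encoder) are assumptions, and positional encoding and attention masking are left out for brevity.

```python
import torch
from torch import nn


class TextEmbeddingSketch(nn.Module):
    """Hypothetical stand-in: an embedding layer followed by a TransformerEncoder."""

    def __init__(self, vocab_size: int = 500, d_model: int = 64,
                 nhead: int = 4, num_layers: int = 2) -> None:
        super().__init__()
        # Maps integer symbol ids to d_model-dimensional vectors.
        self.embed = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=4 * d_model,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, symbols: torch.Tensor) -> torch.Tensor:
        # symbols: (N, U) integer ids -> text_embedding: (N, U, d_model)
        return self.encoder(self.embed(symbols))
```

Calling the module on a `(N, U)` tensor of symbol ids returns a `(N, U, d_model)` tensor that could play the role of text_embedding.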

    dropout=dropout,
    layer_dropout=layer_dropout,
)
self.enhancer = TransformerDecoder(decoder_layer, num_layers)

The EmbeddingEnhancer is a TransformerDecoder that has self-attention from masked_encoder_output and cross-attention from text_embedding.
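To make the data flow concrete, the sketch below wires a decoder so that the masked encoder output is the target sequence (self-attention) and the text embedding is the memory (cross-attention). It assumes the stock `torch.nn.TransformerDecoder`, which has no `layer_dropout` argument, so it only approximates the `TransformerDecoder` used in the diff; the class name `EnhancerSketch` and all sizes are hypothetical.

```python
import torch
from torch import nn


class EnhancerSketch(nn.Module):
    """Hypothetical enhancer: self-attention over the masked encoder output,
    cross-attention to the text embedding."""

    def __init__(self, d_model: int = 64, nhead: int = 4,
                 num_layers: int = 2, dropout: float = 0.1) -> None:
        super().__init__()
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=4 * d_model,
            dropout=dropout, batch_first=True,
        )
        self.enhancer = nn.TransformerDecoder(decoder_layer, num_layers)

    def forward(self, masked_encoder_output: torch.Tensor,
                text_embedding: torch.Tensor) -> torch.Tensor:
        # tgt is attended to by self-attention; memory by cross-attention.
        # Shapes: tgt (N, T, C), memory (N, U, C) -> output (N, T, C).
        return self.enhancer(tgt=masked_encoder_output, memory=text_embedding)
```

The output keeps the time dimension of the encoder output, so it can be used wherever the encoder output would otherwise go.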

N, T, C = embedding.shape
mask = torch.randn((N, T, C), device=embedding.device)
mask = mask > mask_proportion
masked_embedding = torch.masked_fill(embedding, ~mask, 0.0)

I randomly mask the encoder output here.
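One detail worth checking in the snippet above: `torch.randn` draws from a standard normal distribution, so comparing it with `mask_proportion` does not zero out that fraction of elements. If the intent is for `mask_proportion` to be the expected fraction of masked elements (an assumption, the PR does not say), a uniform draw with `torch.rand` would do that. A minimal sketch, with the hypothetical helper name `random_mask`:

```python
import torch


def random_mask(embedding: torch.Tensor, mask_proportion: float) -> torch.Tensor:
    """Zero out roughly `mask_proportion` of the elements of `embedding`.

    Assumes element-wise masking over an (N, T, C) tensor, as in the quoted diff.
    """
    # Uniform samples in [0, 1): an element is kept when its sample is
    # >= mask_proportion, which happens with probability 1 - mask_proportion.
    keep = torch.rand(embedding.shape, device=embedding.device) >= mask_proportion
    return embedding.masked_fill(~keep, 0.0)
```

With `mask_proportion=0.0` every element is kept, and with larger values the zeroed fraction grows proportionally.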

)
return init_context

def delta_wer(

This function implements the sampling process.

+ l2_loss_scale * l2_loss
+ delta_wer_scale * delta_wer_loss
+ predictor_loss_scale * predictor_loss
)

The losses are combined here.
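The visible part of the diff is a weighted sum of auxiliary losses. The sketch below reproduces that pattern; the leading term is not shown in the snippet, so `base_loss` is a hypothetical placeholder for whatever precedes the `+`, and the default scale values are assumptions as well.

```python
def combine_losses(
    base_loss,
    l2_loss,
    delta_wer_loss,
    predictor_loss,
    l2_loss_scale=1.0,
    delta_wer_scale=1.0,
    predictor_loss_scale=1.0,
):
    """Weighted sum of losses; works for Python floats or torch tensors.

    `base_loss` is a placeholder for the term(s) not visible in the quoted diff.
    """
    return (
        base_loss
        + l2_loss_scale * l2_loss
        + delta_wer_scale * delta_wer_loss
        + predictor_loss_scale * predictor_loss
    )
```

Setting a scale to zero drops the corresponding term, which is a convenient way to ablate one loss at a time.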

pkufool (Collaborator, Author) commented Dec 8, 2022

@danpovey @yaozengwei @glynpu Would you please take a look at this? If anything is unclear, please let me know. Thanks!

yaozengwei (Collaborator) commented

Sure. I will have a look.

pkufool closed this Nov 21, 2024