From 26cb6b41ba616fca480b267a0b60864703cc553c Mon Sep 17 00:00:00 2001 From: rj42 Date: Tue, 28 Mar 2023 20:44:36 +0300 Subject: [PATCH 1/2] week07/asr: fix requirements --- week07/asr/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/week07/asr/requirements.txt b/week07/asr/requirements.txt index b06766b..ecde543 100644 --- a/week07/asr/requirements.txt +++ b/week07/asr/requirements.txt @@ -1,5 +1,5 @@ pytorch-lightning==1.5.9 -hydra-core=1.1.1 +hydra-core==1.1.1 editdistance torchaudio torch From 22f6cd2e33f2839cdeae5f58a0d4402751652656 Mon Sep 17 00:00:00 2001 From: rj42 Date: Tue, 28 Mar 2023 20:46:07 +0300 Subject: [PATCH 2/2] week07/asr: fix eval phase in training step: turn off gradients and turn on eval mode --- week07/asr/src/model.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/week07/asr/src/model.py b/week07/asr/src/model.py index f331fef..2e26a00 100644 --- a/week07/asr/src/model.py +++ b/week07/asr/src/model.py @@ -57,19 +57,22 @@ def training_step(self, batch: Any, batch_nb: int): log = {"train_loss": loss, "lr": self.optimizers().param_groups[0]["lr"]} if (batch_nb + 1) % self.log_every_n_steps == 0: - - refs = self.decoder.decode(token_ids=targets, token_ids_length=target_len) - hyps = self.decoder.decode( - token_ids=preds, token_ids_length=encoded_len, unique_consecutive=True - ) - logger.info("reference : %s", refs[0]) - logger.info("prediction: %s", hyps[0]) - self.wer.update(references=refs, hypotheses=hyps) - wer, _, _ = self.wer.compute() - - self.wer.reset() - - log["train_wer"] = wer + with torch.no_grad(): + self.eval() + refs = self.decoder.decode(token_ids=targets, token_ids_length=target_len) + hyps = self.decoder.decode( + token_ids=preds, token_ids_length=encoded_len, unique_consecutive=True + ) + self.train() + + logger.info("reference : %s", refs[0]) + logger.info("prediction: %s", hyps[0]) + self.wer.update(references=refs, hypotheses=hyps) + wer, _, _ = self.wer.compute() + + self.wer.reset() + + log["train_wer"] = wer self.log_dict(log)