Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,11 @@ def get_parser():
'If it is negative, then rescore with the whole lattice.'\
'CAUTION: You have to reduce max_duration in case of CUDA OOM'
)
parser.add_argument(
'--output-dir',
type=str,
default='exp/',
help='output dir for err and recog text')
return parser


Expand Down Expand Up @@ -285,7 +290,8 @@ def main():
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
subsampling_factor=4,
num_decoder_layers=num_decoder_layers,
vgg_frontend=True)
vgg_frontend=True,
is_espnet_structure=True)
elif model_type == "contextnet":
model = ContextNet(
num_features=80,
Expand Down Expand Up @@ -378,7 +384,8 @@ def main():

# load dataset
librispeech = LibriSpeechAsrDataModule(args)
test_sets = ['test-clean', 'test-other']
# test_sets = ['test-clean', 'test-other']
test_sets = ['test-clean']
# test_sets = ['test-other']
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
logging.info(f'* DECODING: {test_set}')
Expand All @@ -393,13 +400,14 @@ def main():
use_whole_lattice=use_whole_lattice,
output_beam_size=output_beam_size)

recog_path = exp_dir / f'recogs-{test_set}.txt'
output_dir = Path(args.output_dir)
recog_path = output_dir / f'recogs-{test_set}.txt'
store_transcripts(path=recog_path, texts=results)
logging.info(f'The transcripts are stored in {recog_path}')

# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = exp_dir / f'errs-{test_set}.txt'
errs_filename = output_dir / f'errs-{test_set}.txt'
with open(errs_filename, 'w') as f:
write_error_stats(f, test_set, results)
logging.info('Wrote detailed error stats to {}'.format(errs_filename))
Expand Down
43 changes: 41 additions & 2 deletions egs/librispeech/asr/simple_v1/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

set -eou pipefail

stage=0
stage=7

if [ $stage -le 1 ]; then
local/download_lm.sh "openslr.org/resources/11" data/local/lm
Expand Down Expand Up @@ -74,6 +74,45 @@ fi

if [ $stage -le 7 ]; then
# python3 ./decode.py # ctc decoding
python3 ./mmi_bigram_decode.py --epoch 9
# python3 ./mmi_bigram_decode.py --epoch 9
# python3 ./mmi_mbr_decode.py
mkdir -p exp/
export CUDA_VISIBLE_DEVICES=3
ln -sf /home/storage14/guoliyong/open-source/snowfall/egs/librispeech/asr/simple_v1/exp-conformer-noam-mmi-att-musan-sa-vgg ./
ln -sf /ceph-ly/open-source/lm_resocre_snowfall/snowfall/egs/librispeech/asr/simple_v1/exp/data ./exp/
ln -sf /ceph-ly/open-source/to_submit/espnet_snowfall/snowfall/egs/librispeech/asr/simple_v1/data ./
output_dir=result_dir
mkdir -p $output_dir
# no rescore
python3 ./mmi_att_transformer_decode.py \
--output-dir $output_dir \
--num-paths -1 \
--max-duration 300 \
--attention-dim 512 \
--use-lm-rescoring False \
--avg 16 \
--epoch 19

# lattice rescore
python3 ./mmi_att_transformer_decode.py \
--output-dir $output_dir \
--num-paths -1 \
--max-duration 300 \
--attention-dim 512 \
--use-lm-rescoring True \
--avg 16 \
--epoch 19
fi

if [ $stage -le 8 ]; then
export CUDA_VISIBLE_DEVICES=3
# nbest rescore
python3 ./mmi_att_transformer_decode.py \
--output-dir $output_dir \
--num-paths 1000 \
--max-duration 300 \
--attention-dim 512 \
--use-lm-rescoring True \
--avg 16 \
--epoch 19
fi
28 changes: 21 additions & 7 deletions snowfall/models/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,21 @@ def __init__(self, num_features: int, num_classes: int, subsampling_factor: int
d_model: int = 256, nhead: int = 4, dim_feedforward: int = 2048,
num_encoder_layers: int = 12, num_decoder_layers: int = 6,
dropout: float = 0.1, cnn_module_kernel: int = 31,
normalize_before: bool = True, vgg_frontend: bool = False) -> None:
normalize_before: bool = True, vgg_frontend: bool = False,
is_espnet_structure: bool = False) -> None:
super(Conformer, self).__init__(num_features=num_features, num_classes=num_classes, subsampling_factor=subsampling_factor,
d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers,
dropout=dropout, normalize_before=normalize_before, vgg_frontend=vgg_frontend)

self.encoder_pos = RelPositionalEncoding(d_model, dropout)

encoder_layer = ConformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, cnn_module_kernel, normalize_before)
encoder_layer = ConformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, cnn_module_kernel, normalize_before, is_espnet_structure)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.normalize_before = normalize_before
self.is_espnet_structure = is_espnet_structure
if self.normalize_before and self.is_espnet_structure:
self.after_norm = nn.LayerNorm(d_model)

def encode(self, x: Tensor, supervisions: Optional[Dict] = None) -> Tuple[Tensor, Optional[Tensor]]:
"""
Expand All @@ -65,6 +70,8 @@ def encode(self, x: Tensor, supervisions: Optional[Dict] = None) -> Tuple[Tensor
mask = encoder_padding_mask(x.size(0), supervisions)
mask = mask.to(x.device) if mask != None else None
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
if self.normalize_before and self.is_espnet_structure:
x = self.after_norm(x)

return x, mask

Expand All @@ -90,9 +97,10 @@ class ConformerEncoderLayer(nn.Module):
"""

def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
cnn_module_kernel: int = 31, normalize_before: bool = True) -> None:
cnn_module_kernel: int = 31, normalize_before: bool = True,
is_espnet_structure=False) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure)

self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Expand Down Expand Up @@ -319,7 +327,8 @@ class RelPositionMultiheadAttention(nn.Module):
>>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
"""

def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.) -> None:
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.,
is_espnet_structure: bool = False) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
Expand All @@ -338,6 +347,7 @@ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.) -> None:
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))

self._reset_parameters()
self.is_espnet_structure = is_espnet_structure

def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
Expand Down Expand Up @@ -538,7 +548,8 @@ def multi_head_attention_forward(self, query: Tensor,
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)

q = q * scaling
if not self.is_espnet_structure:
q = q * scaling

if attn_mask is not None:
assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
Expand Down Expand Up @@ -596,7 +607,10 @@ def multi_head_attention_forward(self, query: Tensor,
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd)

attn_output_weights = (matrix_ac + matrix_bd) # (batch, head, time1, time2)
if not self.is_espnet_structure:
attn_output_weights = (matrix_ac + matrix_bd) # (batch, head, time1, time2)
else:
attn_output_weights = (matrix_ac + matrix_bd) * scaling # (batch, head, time1, time2)

attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)

Expand Down