diff --git a/examples/aishell/RESULT.md b/examples/aishell/RESULT.md index f035cc2ed06fafe7995dbaacf0dabbe52475fce3..573dac8e11c33c1b295b61d29d16ddbf43107442 100644 --- a/examples/aishell/RESULT.md +++ b/examples/aishell/RESULT.md @@ -32,6 +32,17 @@ | attention decoder | 5.50 | | attention rescoring | 5.39 | +## Unified Conformer Result + +* Feature info: using fbank feature, dither=0, cmvn, oneline speed perturb +* Training info: lr 0.001, acc_grad 1, 240 epochs, 4 Ascend910 +* Decoding info: ctc_weight 0.3, average_num 30 +* Performance result: total_time 12h17min + +| decoding mode/chunk size | full | 16 | 8 | 4 | +|---------------------------|-------|-------|-------|-------| +| ctc greedy search | 5.87 | 6.65 | 7.08 | 7.53 | + # Ascend310P Performance Record * Resource: Ascend310P; CPU 2.00GHz; 64cores; memory 250G; OS Euler2.8 diff --git a/examples/aishell/config/asr_unified_conformer.yaml b/examples/aishell/config/asr_unified_conformer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..374d85fd1dd2bcce7753184ad91585c24bc92a96 --- /dev/null +++ b/examples/aishell/config/asr_unified_conformer.yaml @@ -0,0 +1,172 @@ +# network architecture +# encoder related +encoder: conformer +encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: True + cnn_module_kernel: 15 + activation_type: 'swish' + pos_enc_layer_type: 'rel_pos' + feature_norm : True + causal: True + cnn_module_norm: 'layer_norm' + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: False + +# feature extraction +collate_conf: + # feature level config + feature_extraction_conf: + feature_type: 'fbank' + mel_bins: 80 + frame_shift: 10 + frame_length: 25 + using_pitch: False + feature_dither: 0.0 # add dither [-feature_dither,feature_dither] on fbank feature + # data augmentation config + use_speed_perturb: True + use_spec_aug: True + spec_aug_conf: + warp_for_time: False + num_t_mask: 2 + num_f_mask: 2 + prop_mask_t: 0.1 + prop_mask_f: 0.1 + max_t: 50 + max_f: 10 + max_w: 80 + use_dynamic_chunk: True + use_dynamic_left_chunk: False + decoding_chunk_size: 4 + static_chunk_size: 0 + num_decoding_left_chunks: -1 + +# dataset related +dataset_conf: + max_length: 3000 + min_length: 0 + token_max_length: 30 + token_min_length: 1 + batch_type: 'bucket' # bucket, static, dynamic + frame_bucket_limit: '144, 204, 288, 400, 512, 600, 712, 800, 912, 1024, 1112, 1200, 1400, 1600, 2000, 3000' + batch_bucket_limit: '40, 80, 80, 72, 72, 56, 56, 56, 40, 40, 40, 40, 24, 8, 8, 8' + batch_factor: 1 + shuffle: True + +# train option +grad_clip: 5 +accum_grad: 1 +max_epoch: 240 +log_interval: 100 + +optim: adam +optim_conf: + lr: 0.001 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 25000 + +cmvn_file: "/path/train/global_cmvn" +is_json_cmvn: True + +exp_name: default +train_data: "/path/train/format.data" +eval_data: "/path/dev/format.data" +save_checkpoint: True +save_checkpoint_epochs: 1 
+keep_checkpoint_max: 30 +save_checkpoint_path: "./" +device_target: "Ascend" +is_distributed: False +mixed_precision: True +resume_ckpt: "" +save_graphs: False +training_with_eval: True + +# decode option +test_data: "/path/test/format.data" +dict: "/path/dict/lang_char.txt" +decode_ckpt: "avg_30.ckpt" +decode_mode: "attention" # ctc_greedy_search,ctc_prefix_beam_search,attention,attention_rescoring +full_graph: True +decode_batch_size: 1 +ctc_weight: 0.0 +beam_size: 10 +penalty: 0.0 +connect_symbol: '' # use spaces to connect english words. +simulate_streaming: True + +test_dataset_conf: + max_length: 1200 + min_length: 0 + token_max_length: 30 + token_min_length: 1 + batch_type: 'bucket' # bucket, static, dynamic + frame_bucket_limit: '1200' + batch_bucket_limit: '40' + batch_factor: 1 + shuffle: False + +# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) +enable_modelarts: False +# Url for modelarts +data_url: "obs://speech/corpus/aishell1" # dataset path in OBS +train_url: "obs://speech/code/asr/workspace/" # workspace path in OBS +checkpoint_url: "" # pre-train ckpt in OBS +# Path for local +data_path: "/cache/data" +output_path: "/cache/train" +load_path: "/cache/checkpoint_path" +need_modelarts_dataset_unzip: False +modelarts_dataset_unzip_name: "corpora" +mnt_enable: False +ak: "" +sk: "" +server: "" +compile_url: "" + +# mindinsight config +enable_profiling: False +enable_summary: True + +# Config description for each option +use_dynamic_chunk: # 'whether to use dynamic chunk or not, default: False' +use_dynamic_left_chunk: # 'whether to use dynamic left chunk for training, default: False' +decoding_chunk_size: # 'ecoding chunk size for dynamic chunk, it is + # 0: default for training, use random dynamic chunk. + # <0: for decoding, use full chunk. + # >0: for decoding, use fixed chunk size as set.' +static_chunk_size: # 'chunk size for static chunk training/decoding if it is greater than 0, + # if use_dynamic_chunk is true, this parameter will be ignored' +num_decoding_left_chunks: # 'number of left chunks, this is for decoding, the chunk size is decoding_chunk_size. + # >=0: use num_decoding_left_chunks + # <0: use all left chunks' + +# infer for Ascend310 +infer_model_path_1: "" +infer_model_path_2: "" +infer_data_path: "" +config_name: "" +label_file: "/path/test/text" # Used to calculate CER. diff --git a/examples/aishell/run.sh b/examples/aishell/run.sh index ce6e31ebf343b0dba6faad63962ff662131ef8fb..4e389dece713096994562966d19460d7e7b530f8 100644 --- a/examples/aishell/run.sh +++ b/examples/aishell/run.sh @@ -37,6 +37,7 @@ train_set=train # Optional train_config # 1. config/asr_transformer.yaml # 2. config/asr_conformer.yaml +# 3. config/asr_unified_conformer.yaml train_config=config/asr_conformer.yaml training_with_eval=True # ckpt file will be saved at $exp/${net_name}/${exp_name}/model diff --git a/flyspeech/decode/predict_net.py b/flyspeech/decode/predict_net.py index bee15ea9d324ee8c970c24a36cdd6fda9b7e6927..3fb16681225b4adb9253bd773acab336f6ab66dc 100644 --- a/flyspeech/decode/predict_net.py +++ b/flyspeech/decode/predict_net.py @@ -149,10 +149,12 @@ class CTCGreedySearch(nn.Cell): pretrained_model (nn.Cell): speech pre-trained models, like wav2vec 2.0. 
""" - def __init__(self, backbone, pretrained_model=False): + def __init__(self, backbone, pretrained_model=False, simulate_streaming=False, decoding_chunk_size=4): super(CTCGreedySearch, self).__init__() self.backbone = backbone self.pretrained_model = pretrained_model + self.simulate_streaming = simulate_streaming + self.decoding_chunk_size = decoding_chunk_size self.encoder = self.backbone.acc_net.encoder self.topk = ops.TopK() self.cast = ops.Cast() @@ -175,8 +177,11 @@ class CTCGreedySearch(nn.Cell): xs_pad, xs_masks, xs_lengths) if self.backbone.acc_net.feature_post_proj: xs_pad = self.backbone.acc_net.feature_post_proj(xs_pad) - xs_masks = xs_masks[:, :, :-2:2][:, :, :-2:2] - encoder_out, encoder_mask = self.encoder(xs_pad, xs_masks, xs_masks) + if self.simulate_streaming: + encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk(xs_pad, xs_masks, self.decoding_chunk_size) + else: + xs_masks = xs_masks[:, :, :-2:2][:, :, :-2:2] + encoder_out, encoder_mask = self.encoder(xs_pad, xs_masks, xs_masks) ctc_probs = self.backbone.acc_net.ctc.compute_log_softmax_out(encoder_out) # (B, T, 1) topk_prob, topk_index = self.topk(ctc_probs, 1) # (B, T) diff --git a/flyspeech/transformer/attention.py b/flyspeech/transformer/attention.py index e2b660d7b7f1ee45cb28db89aec1cf3137e6f2c7..b5309951cff55377598d5c77e8784e913a74735d 100644 --- a/flyspeech/transformer/attention.py +++ b/flyspeech/transformer/attention.py @@ -63,6 +63,8 @@ class MultiHeadedAttention(nn.Cell): self.mul = ops.Mul() self.add = ops.Add() self.get_dtype = ops.DType() + self.concat = ops.Concat(axis=2) + self.concat_1 = ops.Concat(axis=-1) def forward_qkv(self, query: mindspore.Tensor, key: mindspore.Tensor, value: mindspore.Tensor) -> Tuple[mindspore.Tensor, mindspore.Tensor, mindspore.Tensor]: @@ -129,7 +131,9 @@ class MultiHeadedAttention(nn.Cell): key: mindspore.Tensor, value: mindspore.Tensor, mask: Optional[mindspore.Tensor], - pos_emb: Optional[mindspore.Tensor] = None) -> mindspore.Tensor: # pylint: disable=W0613 + pos_emb: Optional[mindspore.Tensor] = None, # pylint: disable=W0613 + cache: mindspore.Tensor = ops.Zeros()((0, 0, 0, 0), mindspore.float32))\ + -> mindspore.Tensor: """Compute scaled dot product attention. Args: @@ -150,13 +154,23 @@ class MultiHeadedAttention(nn.Cell): Wenet. pos_emb (mindspore.Tensor): Positional embedding tensor (#batch, time2, size). + cache (mindspore.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` Returns: mindspore.Tensor: Output tensor (#batch, time1, d_model). """ q, k, v = self.forward_qkv(query, key, value) + if cache.shape[2] > 0: + index = cache.shape[-1]//2 + key_cache = cache[:, :, :, 0:index] + value_cache = cache[:, :, :, index:] + k = self.concat((key_cache, k)) + v = self.concat((value_cache, v)) + new_cache = self.concat_1((k, v)) scores = self.matmul(q * self.scores_mul, k.transpose(0, 1, 3, 2) * self.scores_mul) - return self.forward_attention(v, scores, mask) + return self.forward_attention(v, scores, mask), new_cache class RelPositionMultiHeadedAttention(MultiHeadedAttention): @@ -186,7 +200,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): key: mindspore.Tensor, value: mindspore.Tensor, mask: Optional[mindspore.Tensor], - pos_emb: Optional[mindspore.Tensor] = None): + pos_emb: Optional[mindspore.Tensor] = None, + cache: mindspore.Tensor = ops.Zeros()((0, 0, 0, 0), mindspore.float32)): """Compute 'Scaled Dot Product Attention' with rel. positional encoding. 
@@ -198,6 +213,9 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): (#batch, time1, time2). pos_emb (mindspore.Tensor): Positional embedding tensor (#batch, time2, size). + cache (mindspore.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` Returns: mindspore.Tensor: Output tensor (#batch, time1, d_model). """ @@ -207,6 +225,14 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): q, k, v = self.forward_qkv(query, key, value) q = q.transpose(0, 2, 1, 3) # (batch, time1, head, d_k) + if cache.shape[2] > 0: + index = cache.shape[-1]//2 + key_cache = cache[:, :, :, 0:index] + value_cache = cache[:, :, :, index:] + k = self.concat((key_cache, k)) + v = self.concat((value_cache, v)) + new_cache = self.concat_1((k, v)) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) p = p.transpose(0, 2, 1, 3) # (batch, head, time1, d_k) @@ -230,4 +256,4 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): scores = matrix_ac + matrix_bd scores = self.mul(scores, self.scores_mul) - return self.forward_attention(v, scores, mask) + return self.forward_attention(v, scores, mask), new_cache diff --git a/flyspeech/transformer/convolution.py b/flyspeech/transformer/convolution.py index df33f7c5bb8945c7abdb25298bf5f4a2d9b597a1..c296cbfa48e8787a100dd2236ea90f954e75d46d 100644 --- a/flyspeech/transformer/convolution.py +++ b/flyspeech/transformer/convolution.py @@ -34,6 +34,7 @@ class ConvolutionModule(nn.Cell): channels (int): The number of channels of conv layers. kernel_size (int): Kernel size of conv layers. activation (nn.Cell): Activation function of CNN module. + causal (bool): whether to use causal convolution or not. norm (str): Normalize type of CNN module, batch norm or layer norm. glu_dim (int): Dimension of GLU activation function. bias (bool): Whether use bias for CNN layer. @@ -44,6 +45,7 @@ class ConvolutionModule(nn.Cell): channels: int, kernel_size: int, activation: nn.Cell, + causal: bool = False, norm: str = 'batch_norm', glu_dim: int = 1, bias: bool = True, @@ -57,11 +59,26 @@ class ConvolutionModule(nn.Cell): has_bias=bias, pad_mode='valid', enable_mask_padding_feature=False).to_float(compute_type) + + # self.lorder is used to distinguish if it's a causal convolution, + # if self.lorder > 0: it's a causal convolution, the input will be + # padded with self.lorder frames on the left in forward. + # else: it's a symmetrical convolution + if causal: + padding = 0 + self.lorder = kernel_size - 1 + else: + # kernel_size should be an odd number for none causal convolution + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.lorder = 0 + if self.lorder > 0: + self.pad = nn.Pad(paddings=((0, 0), (0, 0), (self.lorder, 0)), mode="CONSTANT") self.depthwise_conv = Conv1d(channels, channels, kernel_size, stride=1, - padding=(kernel_size-1) // 2, + padding=padding, group=channels, has_bias=bias, pad_mode='pad', @@ -87,15 +104,22 @@ class ConvolutionModule(nn.Cell): self.glu = GLU(dim=glu_dim) self.reshape = ops.Reshape() self.cast = ops.Cast() + self.zeros = ops.Zeros() + self.concat = ops.Concat(axis=2) def construct(self, x: mindspore.Tensor, - mask: mindspore.Tensor = None) -> Tuple[mindspore.Tensor, mindspore.Tensor]: + mask: mindspore.Tensor = None, + cache: mindspore.Tensor = ops.Zeros()((0, 0, 0), mindspore.float32))\ + -> Tuple[mindspore.Tensor, mindspore.Tensor]: """Compute convolution module. 
Args: x (mask.Tensor): Input tensor (#batch, time, channels). - mask (mindspore.Tensor): Mask (#batch, 1, time) + mask (mindspore.Tensor): Mask (#batch, 1, time). + cache (mindspore.Tensor): left context cache, it is only + used in causal convolution (#batch, channels, cache_t), + (0, 0, 0) meas fake cache. Returns: mindspore.Tensor: Output tensor (#batch, time, channels). """ @@ -105,6 +129,18 @@ class ConvolutionModule(nn.Cell): if mask is not None: x = x * mask + if self.lorder > 0: + if cache.shape[2] == 0: + x = self.pad(x) + else: + assert cache.shape[0] == x.shape[0] + assert cache.shape[1] == x.shape[1] + x = self.concat((cache, x)) + assert x.shape[2] > self.lorder + new_cache = x[:, :, -self.lorder:] + else: + new_cache = self.zeros((0, 0, 0), x.dtype) + # GLU mechanism x = self.pointwise_conv1(x) # (batch, 2 * channels, time) x = self.glu(x) # (batch, channels, time) @@ -134,4 +170,4 @@ class ConvolutionModule(nn.Cell): if mask is not None: x = x * mask - return x.transpose(0, 2, 1) + return x.transpose(0, 2, 1), new_cache diff --git a/flyspeech/transformer/decoder_layer.py b/flyspeech/transformer/decoder_layer.py index cff513f134312373f3b45408072b3506e7e5e8c0..e7ece97779d8fa467f563f3fe1614503794c3be6 100644 --- a/flyspeech/transformer/decoder_layer.py +++ b/flyspeech/transformer/decoder_layer.py @@ -109,10 +109,10 @@ class DecoderLayer(nn.Cell): tgt_q = tgt tgt_q_mask = tgt_mask if self.concat_after: - tgt_concat = self.cat1((tgt_q, self.cast(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask), tgt_q.dtype))) + tgt_concat = self.cat1((tgt_q, self.cast(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0], tgt_q.dtype))) x = residual + self.concat_linear1(tgt_concat) else: - x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) + x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) if not self.normalize_before: x = self.norm1(x) @@ -122,10 +122,10 @@ class DecoderLayer(nn.Cell): x = self.norm2(x) if self.concat_after: - x_concat = self.cat1((x, self.cast(self.src_attn(x, memory, memory, memory_mask), x.dtype))) + x_concat = self.cat1((x, self.cast(self.src_attn(x, memory, memory, memory_mask)[0], x.dtype))) x = residual + self.concat_linear2(x_concat) else: - x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask)) + x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask)[0]) if not self.normalize_before: x = self.norm2(x) diff --git a/flyspeech/transformer/embedding.py b/flyspeech/transformer/embedding.py index 56302ea068759e1d5f9c8fa0d7b381a364baecf5..caae4c48672931d54414d30fb01cf37bfcbf461c 100644 --- a/flyspeech/transformer/embedding.py +++ b/flyspeech/transformer/embedding.py @@ -60,10 +60,10 @@ class PositionalEncoding(nn.Cell): """Compute positional encoding. Args: - x (minspore.Tensor): Input tensor (batch, time, `*`). + x (mindspore.Tensor): Input tensor (batch, time, `*`). Returns: - minspore.Tensor: Encoded tensor (batch, time, `*`). - minspore.Tensor: Positional embedding tensor (1, time, `*`). + mindspore.Tensor: Encoded tensor (batch, time, `*`). + mindspore.Tensor: Positional embedding tensor (1, time, `*`). """ pos_emb = self.pe[:, offset:offset + x.shape[1]] x = x * self.xscale + pos_emb @@ -87,10 +87,10 @@ class RelPositionalEncoding(PositionalEncoding): """Compute positional encoding. Args: - x (minspore.Tensor): Input tensor (batch, time, `*`). + x (mindspore.Tensor): Input tensor (batch, time, `*`). Returns: - minspore.Tensor: Encoded tensor (batch, time, `*`). 
- minspore.Tensor: Positional embedding tensor (1, time, `*`). + mindspore.Tensor: Encoded tensor (batch, time, `*`). + mindspore.Tensor: Positional embedding tensor (1, time, `*`). """ x = x * self.xscale pos_emb = self.pe[:, offset:offset + x.shape[1]] @@ -130,10 +130,10 @@ class ConvPositionalEncoding(nn.Cell): """Compute positional encoding. Args: - x (minspore.Tensor): Input tensor (batch, time, `*`). + x (mindspore.Tensor): Input tensor (batch, time, `*`). Returns: - minspore.Tensor: Encoded tensor (batch, time, `*`). - minspore.Tensor: Positional embedding tensor (1, time, `*`). + mindspore.Tensor: Encoded tensor (batch, time, `*`). + mindspore.Tensor: Positional embedding tensor (1, time, `*`). """ b, t, c = x.shape # B T C x_pos = x.transpose(0, 2, 1) # B C T @@ -162,10 +162,10 @@ class NoPositionalEncoding(nn.Cell): def construct(self, x: mindspore.Tensor, offset: int = 0): # pylint: disable=W0613 """Just return zero vector for interface compatibility Args: - x (minspore.Tensor): Input tensor (batch, time, `*`). + x (mindspore.Tensor): Input tensor (batch, time, `*`). Returns: - minspore.Tensor: Encoded tensor (batch, time, `*`). - minspore.Tensor: Positional embedding tensor (1, time, `*`). + mindspore.Tensor: Encoded tensor (batch, time, `*`). + mindspore.Tensor: Positional embedding tensor (1, time, `*`). """ pos_emb = self.zeros((1, x.shape[1], self.d_model), mstype.float32) return self.dropout(x), pos_emb diff --git a/flyspeech/transformer/encoder.py b/flyspeech/transformer/encoder.py index 1d5bd041dc64999f82da2912d426975423b2430d..6ff5bc1e37b84079be6e4286d881b433bd2dff0a 100644 --- a/flyspeech/transformer/encoder.py +++ b/flyspeech/transformer/encoder.py @@ -22,6 +22,7 @@ from typing import Tuple import mindspore import mindspore.common.dtype as mstype import mindspore.nn as nn +import mindspore.ops as ops from flyspeech.layers.layernorm import LayerNorm from flyspeech.transformer.attention import MultiHeadedAttention, RelPositionMultiHeadedAttention @@ -89,6 +90,10 @@ class BaseEncoder(nn.Cell): self.feature_norm = feature_norm self.global_cmvn = global_cmvn + self.expanddims = ops.ExpandDims() + self.concat = ops.Concat(axis=0) + self.concat_1 = ops.Concat(axis=1) + self.concat_2 = ops.Concat(axis=-1) def output_size(self) -> int: return self._output_size @@ -114,7 +119,7 @@ class BaseEncoder(nn.Cell): # masks is subsampled to (B, 1, T/subsample_rate) xs, pos_emb = self.embed(xs) for layer in self.encoders: - xs, xs_chunk_masks = layer(xs, xs_chunk_masks, pos_emb, masks) + xs, xs_chunk_masks, _, _ = layer(xs, xs_chunk_masks, pos_emb, masks) if self.normalize_before: xs = self.after_norm(xs) # Here we assume the mask is not changed in encoder layers, so just @@ -122,6 +127,152 @@ class BaseEncoder(nn.Cell): # for cross attention with decoder later return xs, masks + def forward_chunk( + self, + xs: mindspore.Tensor, + chunk_masks: mindspore.Tensor, + masks: mindspore.Tensor, + offset: int, + required_cache_size: int, + att_cache: mindspore.Tensor = ops.Zeros()((0, 0, 0, 0), mindspore.float32), + cnn_cache: mindspore.Tensor = ops.Zeros()((0, 0, 0, 0), mindspore.float32), + ) -> Tuple[mindspore.Tensor, mindspore.Tensor, mindspore.Tensor]: + """ Forward just one chunk + + Args: + xs (mindspore.Tensor): chunk input, with shape (b=1, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + chunk_masks (mindspore.Tensor): Mask tensor for the input (#batch, 1, time). 
+ masks (mindspore.Tensor): masks for the input xs () + offset (int): current offset in encoder output time stamp + required_cache_size (int): cache size required for next chunk + computation + >=0: actual cache size + <0: means all history cache is required + att_cache (mindspore.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. + cnn_cache (mindspore.Tensor): cache tensor for cnn_module in conformer, + (elayers, b=1, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + + Returns: + mindspore.Tensor: output of current input xs, + with shape (b=1, chunk_size, hidden-dim). + mindspore.Tensor: new attention cache required for next chunk, with + dynamic shape (elayers, head, ?, d_k * 2) + depending on required_cache_size. + mindspore.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. + + """ + assert xs.shape[0] == 1 + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim) + xs, pos_emb = self.embed(xs, offset) + # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim) + elayers, cache_t1 = att_cache.shape[0], att_cache.shape[2] + chunk_size = xs.shape[1] + attention_key_size = cache_t1 + chunk_size + pos_emb = self.embed.position_encoding( + offset=offset - cache_t1, size=attention_key_size) + if required_cache_size < 0: + next_cache_start = 0 + elif required_cache_size == 0: + next_cache_start = attention_key_size + else: + next_cache_start = max(attention_key_size - required_cache_size, 0) + r_att_cache = [] + r_cnn_cache = [] + for i, layer in enumerate(self.encoders): + # NOTE(xcsong): Before layer.forward + # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), + # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) + xs, _, new_att_cache, new_cnn_cache = layer( + xs, chunk_masks, pos_emb, masks, + att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, + cnn_cache=cnn_cache[i] if cnn_cache.shape[0] > 0 else cnn_cache + ) + # NOTE(xcsong): After layer.forward + # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), + # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) + r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) + r_cnn_cache.append(self.expanddims(new_cnn_cache, 0)) + if self.normalize_before: + xs = self.after_norm(xs) + + r_att_cache = self.concat(r_att_cache) + r_cnn_cache = self.concat(r_cnn_cache) + return xs, r_att_cache, r_cnn_cache + + def forward_chunk_by_chunk( + self, + xs: mindspore.Tensor, + masks: mindspore.Tensor, + decoding_chunk_size: int, + num_decoding_left_chunks: int = -1, + ) -> Tuple[mindspore.Tensor, mindspore.Tensor]: + """ Forward input chunk by chunk with chunk_size like a streaming + fashion + + Here we should pay special attention to computation cache in the + streaming style forward chunk by chunk. Three things should be taken + into account for computation in the current network: + 1. transformer/conformer encoder layers output cache + 2. convolution in conformer + 3. convolution in subsampling + + However, we don't implement subsampling cache for: + 1. We can control subsampling module to output the right result by + overlapping input instead of cache left context, even though it + wastes some computation, but subsampling only takes a very + small fraction of computation in the whole model. + 2. 
Typically, there are several covolution layers with subsampling + in subsampling module, it is tricky and complicated to do cache + with different convolution layers with different subsampling + rate. + 3. Currently, nn.Sequential is used to stack all the convolution + layers in subsampling, we need to rewrite it to make it work + with cache, which is not preferred. + Args: + xs (torch.Tensor): (1, max_len, dim) + chunk_size (int): decoding chunk size + """ + assert decoding_chunk_size > 0 + subsampling = self.embed.subsampling_rate + context = self.embed.right_context + 1 # Add current frame + stride = subsampling * decoding_chunk_size + decoding_window = (decoding_chunk_size - 1) * subsampling + context + num_frames = xs.shape[1] + att_cache = ops.Zeros()((0, 0, 0, 0), mindspore.float32) + cnn_cache = ops.Zeros()((0, 0, 0, 0), mindspore.float32) + chunk_masks = ops.Zeros()((0, 0, 0, 0), mindspore.float32) + outputs = [] + offset = 0 + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + + # Feed forward overlap input step by step + for cur in range(0, num_frames - context + 1, stride): + end = min(cur + decoding_window, num_frames) + chunk_xs = xs[:, cur:end, :] + masks_pad = masks[:, :, cur:end] + masks_pad = masks_pad[:, :, :-2:2][:, :, :-2:2] + if cur == 0: + chunk_masks = masks_pad + else: + chunk_masks = self.concat_2((chunk_masks, masks_pad)) + (y, att_cache, cnn_cache) = self.forward_chunk( + chunk_xs, chunk_masks, masks_pad, offset, required_cache_size, att_cache, cnn_cache) + outputs.append(y) + offset += y.shape[1] + ys = self.concat_1(outputs) + masks = ops.Ones()((1, 1, ys.shape[1]), mindspore.bool_) + return ys, masks + class TransformerEncoder(BaseEncoder): """Transformer encoder module. @@ -256,6 +407,7 @@ class ConformerEncoder(BaseEncoder): False: x -> x + att(x) activation_type (str): type of activation type. cnn_module_kernel (int): kernel size for CNN module + causal (bool): whether to use causal convolution or not. cnn_module_norm (str): normalize type for CNN module, batch norm or layer norm. compute_type (dtype): whether to use mix precision training. """ @@ -276,6 +428,7 @@ class ConformerEncoder(BaseEncoder): concat_after: bool = False, activation_type: str = 'swish', cnn_module_kernel: int = 15, + causal: bool = False, cnn_module_norm: str = 'batch_norm', global_cmvn: mindspore.nn.Cell = None, compute_type=mstype.float32): @@ -310,7 +463,8 @@ class ConformerEncoder(BaseEncoder): # convolution module definition convolution_layer = ConvolutionModule - convolution_layer_args = (output_size, cnn_module_kernel, activation, cnn_module_norm, 1, True, compute_type) + convolution_layer_args = (output_size, cnn_module_kernel, activation, causal, + cnn_module_norm, 1, True, compute_type) self.encoders = nn.CellList([ ConformerEncoderLayer( diff --git a/flyspeech/transformer/encoder_layer.py b/flyspeech/transformer/encoder_layer.py index 4977508ddb30fbe2d7f025f817a0188705d15ae5..91fe4fe1d5a209b8c453ac6a36b54065cdce4afe 100644 --- a/flyspeech/transformer/encoder_layer.py +++ b/flyspeech/transformer/encoder_layer.py @@ -81,13 +81,13 @@ class TransformerEncoderLayer(nn.Cell): """Compute encoded features. Args: - x (minspore.Tensor): Input tensor (#batch, time, size). - mask (minspore.Tensor): Mask tensor for the input (#batch, 1, time). - output_cache (minspore.Tensor): Cache tensor of the output + x (mindspore.Tensor): Input tensor (#batch, time, size). + mask (mindspore.Tensor): Mask tensor for the input (#batch, 1, time). 
+ output_cache (mindspore.Tensor): Cache tensor of the output (#batch, time2, size), time2 < time in x. Returns: - minspore.Tensor: Output tensor (#batch, time, size). - minspore.Tensor: Mask tensor (#batch, time). + mindspore.Tensor: Output tensor (#batch, time, size). + mindspore.Tensor: Mask tensor (#batch, time). """ # Multi-headed self-attention module residual = x @@ -183,6 +183,7 @@ class ConformerEncoderLayer(nn.Cell): self.cast = ops.Cast() self.get_dtype = ops.DType() self.compute_type = compute_type + self.zeros = ops.Zeros() def construct( self, @@ -190,21 +191,27 @@ class ConformerEncoderLayer(nn.Cell): mask: mindspore.Tensor, pos_emb: mindspore.Tensor, mask_pad: mindspore.Tensor, - output_cache: Optional[mindspore.Tensor] = None + att_cache: mindspore.Tensor = ops.Zeros()((0, 0, 0, 0), mindspore.float32), + cnn_cache: mindspore.Tensor = ops.Zeros()((0, 0, 0, 0), mindspore.float32), ) -> Tuple[mindspore.Tensor, mindspore.Tensor, mindspore.Tensor]: """Compute encoded features. Args: - x (minspore.Tensor): (#batch, time, size) - mask (minspore.Tensor): Mask tensor for the input (#batch, 1, time). - pos_emb (minspore.Tensor): positional encoding, must not be None + x (mindspore.Tensor): (#batch, time, size) + mask (mindspore.Tensor): Mask tensor for the input (#batch, 1, time). + pos_emb (mindspore.Tensor): positional encoding, must not be None for ConformerEncoderLayer. mask_pad (mindspore.Tensor): mask for input tensor. - output_cache (minspore.Tensor): Cache tensor of the output - (#batch, time2, size), time2 < time in x. + att_cache (mindspore.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (mindspore.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2) Returns: - minspore.Tensor: Output tensor (#batch, time, size). - minspore.Tensor: Mask tensor (#batch, time). + mindspore.Tensor: Output tensor (#batch, time, size). + mindspore.Tensor: Mask tensor (#batch, time). + mindspore.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + mindspore.Tensor: cnn_cahce tensor (#batch, size, cache_t2). 
""" # Macaron-Net Feedforward module residual = x @@ -219,16 +226,7 @@ class ConformerEncoderLayer(nn.Cell): if self.normalize_before: x = self.norm_mha(x) - if output_cache is None: - x_q = x - else: - # TODO: wait to be reviewed - chunk = x.shape[1] - output_cache.shape[1] - x_q = x[:, -chunk:, :] - residual = residual[:, -chunk:, :] - mask = mask[:, -chunk:, :] - - x_att = self.self_attn(x_q, x, x, mask, pos_emb) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache) # TODO: need to be reviewed if self.concat_after: @@ -243,7 +241,8 @@ class ConformerEncoderLayer(nn.Cell): residual = x if self.normalize_before: x = self.norm_conv(x) - x = residual + self.dropout(self.conv_module(x, mask_pad)) + x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) + x = residual + self.dropout(x) if not self.normalize_before: x = self.norm_conv(x) @@ -258,7 +257,4 @@ class ConformerEncoderLayer(nn.Cell): # Final normalization x = self.norm_final(x) - if output_cache is not None: - x = self.cat_1([output_cache, x], dim=1) - - return x, mask + return x, mask, new_att_cache, new_cnn_cache diff --git a/flyspeech/transformer/subsampling.py b/flyspeech/transformer/subsampling.py index c0a7268d7d6df54a1b8bc548db32eb40bb7416fb..7b55694cf77aefb470cd61c5bf08c48c03caf43e 100644 --- a/flyspeech/transformer/subsampling.py +++ b/flyspeech/transformer/subsampling.py @@ -69,15 +69,15 @@ class Conv2dSubsampling4(BaseSubsampling): """Subsample x. Args: - x (minspore.Tensor): Input tensor (#batch, time, idim). - x_mask (minspore.Tensor): Input mask (#batch, 1, time). + x (mindspore.Tensor): Input tensor (#batch, time, idim). + x_mask (mindspore.Tensor): Input mask (#batch, 1, time). Returns: - minspore.Tensor: Subsampled tensor (#batch, time', odim), + mindspore.Tensor: Subsampled tensor (#batch, time', odim), where time' = time // 4. - minspore.Tensor: Subsampled mask (#batch, 1, time'), + mindspore.Tensor: Subsampled mask (#batch, 1, time'), where time' = time // 4. 
- minspore.Tensor: positional encoding + mindspore.Tensor: positional encoding """ x = self.expanddims(x, 1) # (b, c=1, t, f) x = self.conv(x) diff --git a/flyspeech/utils/mask.py b/flyspeech/utils/mask.py index caad20ad840ed56907c916f872edf429149e402e..385a9fc1f9f7e64ca867729e6fe77754f6c47959 100644 --- a/flyspeech/utils/mask.py +++ b/flyspeech/utils/mask.py @@ -264,7 +264,7 @@ def add_optional_chunk_mask(xs_len, masks, use_dynamic_chunk, use_dynamic_left_c if chunk_size > max_len // 2: chunk_size = max_len else: - chunk_size = chunk_size%25 + 1 + chunk_size = chunk_size % 25 + 1 if use_dynamic_left_chunk: max_left_chunks = (max_len-1) // chunk_size num_left_chunks = np.random.randint(0, max_left_chunks, (1,)).tolist()[0] diff --git a/predict.py b/predict.py index b1565d5913b5f11420397563bde29f9377914536..8509fa3e12cf5f837134a9887520c3c500b0e521 100644 --- a/predict.py +++ b/predict.py @@ -75,9 +75,10 @@ def main(): load_param_into_net(network, param_dict) logger.info('Successfully loading the asr model: %s', decode_ckpt) network.set_train(False) - + simulate_streaming = config.get("simulate_streaming", False) + chunk_size = config.collate_conf.decoding_chunk_size if config.decode_mode == 'ctc_greedy_search': - model = Model(CTCGreedySearch(network)) + model = Model(CTCGreedySearch(network, simulate_streaming=simulate_streaming, decoding_chunk_size=chunk_size)) elif config.decode_mode == 'attention' and config.full_graph: model = Model(Attention(network, config.beam_size, eos)) elif config.decode_mode == 'attention' and not config.full_graph:
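
With simulate_streaming enabled, predict.py builds CTCGreedySearch so that decoding goes through encoder.forward_chunk_by_chunk instead of a single full-utterance pass; the 16/8/4 columns of the new RESULT.md table presumably correspond to decoding_chunk_size, and the "full" column to the non-streaming path. The encoder slides an overlapping window over the input frames and re-runs the subsampling on the overlap rather than caching it. A minimal sketch of that windowing, assuming a subsampling rate of 4 and a right context of 6 (in the patch the real values come from self.embed):

```python
# Pure-Python sketch of the window arithmetic in forward_chunk_by_chunk. The
# subsampling rate (4) and right context (6) are assumed values for the
# Conv2dSubsampling4 front end; the patch reads them from self.embed.

def chunk_windows(num_frames, decoding_chunk_size=4, subsampling=4, right_context=6):
    """Yield the (start, end) frame ranges fed to forward_chunk, one per chunk."""
    context = right_context + 1                    # current frame + right context
    stride = subsampling * decoding_chunk_size     # hop between chunk starts
    decoding_window = (decoding_chunk_size - 1) * subsampling + context
    for cur in range(0, num_frames - context + 1, stride):
        yield cur, min(cur + decoding_window, num_frames)

# 100 feature frames with chunk size 4 -> overlapping 19-frame windows every 16 frames
print(list(chunk_windows(100)))
```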
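Inside each window, every encoder layer keeps an attention cache holding the keys and values of earlier chunks, packed along the last axis as (1, head, cache_t, d_k * 2). The sketch below reproduces the split-and-concatenate bookkeeping added to MultiHeadedAttention, in NumPy for illustration only, not the MindSpore ops used in the patch:

```python
import numpy as np

def attend_with_cache(k, v, cache):
    """k, v: (1, head, time, d_k); cache: (1, head, cache_t, d_k * 2), cache_t may be 0."""
    if cache.shape[2] > 0:
        d_k = cache.shape[-1] // 2
        key_cache, value_cache = cache[..., :d_k], cache[..., d_k:]
        k = np.concatenate((key_cache, k), axis=2)    # prepend cached keys on the time axis
        v = np.concatenate((value_cache, v), axis=2)  # prepend cached values
    new_cache = np.concatenate((k, v), axis=-1)       # pack K|V for the next chunk
    return k, v, new_cache

head, d_k = 4, 64
k = np.zeros((1, head, 4, d_k)); v = np.zeros((1, head, 4, d_k))
empty = np.zeros((1, head, 0, 2 * d_k))
k1, v1, cache1 = attend_with_cache(k, v, empty)   # first chunk: empty cache
k2, v2, cache2 = attend_with_cache(k, v, cache1)  # second chunk sees 4 cached frames
print(cache1.shape, cache2.shape)                 # (1, 4, 4, 128) (1, 4, 8, 128)
```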
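The conformer convolution module is made causal (causal: True in the new config), so a chunk needs exactly lorder = kernel_size - 1 frames of left context: zeros on the first chunk, the previous chunk's cnn_cache afterwards. A NumPy sketch of that padding/caching, assuming the (batch, channels, time) layout used after the transpose in ConvolutionModule:

```python
import numpy as np

def causal_conv_input(x, cache, kernel_size=15):
    """x: (batch, channels, time) chunk; cache: (batch, channels, cache_t), cache_t may be 0."""
    lorder = kernel_size - 1
    if cache.shape[2] == 0:
        x = np.pad(x, ((0, 0), (0, 0), (lorder, 0)))  # first chunk: zero-pad on the left
    else:
        x = np.concatenate((cache, x), axis=2)         # later chunks: reuse cached context
    new_cache = x[:, :, -lorder:]                      # keep the last lorder frames
    return x, new_cache

x = np.ones((1, 256, 4))                               # one decoded chunk, 4 frames
x1, c1 = causal_conv_input(x, np.zeros((1, 256, 0)))
x2, c2 = causal_conv_input(x, c1)
print(x1.shape, c1.shape, x2.shape)                    # (1, 256, 18) (1, 256, 14) (1, 256, 18)
```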
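forward_chunk then decides how much of the grown attention cache to hand to the next chunk via required_cache_size = decoding_chunk_size * num_decoding_left_chunks; with the config default num_decoding_left_chunks: -1 this is negative and the whole history is kept. A small sketch of that trimming rule:

```python
# Sketch of the cache-trimming rule in forward_chunk: which time index of the new
# attention cache becomes the start of the cache passed to the next chunk.

def next_cache_start(cache_t1, chunk_size, required_cache_size):
    attention_key_size = cache_t1 + chunk_size
    if required_cache_size < 0:
        return 0                                   # keep the full history
    if required_cache_size == 0:
        return attention_key_size                  # keep nothing
    return max(attention_key_size - required_cache_size, 0)

# chunk size 4 with 4 left chunks: once 16 frames are cached, the window slides by 4
assert next_cache_start(cache_t1=16, chunk_size=4, required_cache_size=16) == 4
# num_decoding_left_chunks = -1 -> required_cache_size < 0 -> keep everything
assert next_cache_start(cache_t1=0, chunk_size=4, required_cache_size=-4) == 0
```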
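On the training side, use_dynamic_chunk: True makes the encoder robust to any chunk size by sampling one per batch and folding it as in the flyspeech/utils/mask.py hunk: large draws fall back to full-utterance attention, small ones become a 1-25 frame streaming chunk. The sketch below isolates that folding; the initial uniform draw of chunk_size is an assumption, since the patch only shows the folding branch.

```python
import numpy as np

def fold_chunk_size(max_len: int, rng: np.random.RandomState) -> int:
    chunk_size = int(rng.randint(1, max_len))   # assumed draw; not shown in the patch
    if chunk_size > max_len // 2:
        return max_len                          # train this batch with full context
    return chunk_size % 25 + 1                  # otherwise a small chunk in 1..25

rng = np.random.RandomState(0)
print([fold_chunk_size(200, rng) for _ in range(5)])
```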