• WeNet更新:喜马拉雅团队在 WeNet 中支持 Squeezeformer


    WeNet在正式发布两年的时间里,成为非常热门的ASR生产工具,其面向生产的属性更是深受工业界的好评。近期,喜马拉雅团队在WeNet中支持了Squeezeformer的相关工作。本文由喜马拉雅珠峰智能实验室撰写,介绍了Squeezeformer论文的复现细节,包括训练方案、流式推理以及实验结果。

    喜马拉雅珠峰智能实验室:聚焦音视频以及智能语音技术,先后打造了语音合成(TTS)、语音识别(ASR)、智能审核、语音唤醒、智能音效、降噪、智能配乐、虚拟人讲书等产品和能力。通过行业领先的TTS技术,喜马拉雅用AIGC(AI生成内容)引领长音频行业的内容生产变革,让内容生产提效。推出的”AI开放平台“和 辅助创作者生产工具“喜韵音坊“ 为B端和C端的用户和创作者提供服务。与此同时,音视频实验室通过AI文稿、图像生成等多项音视频技术,进一步提升喜马拉雅用户的内容消费体验。“单田芳声音重现”等账号下上线的运用单田芳AI合成音所制作的专辑数量已经有100多张,总播放量超过1亿。

    论文介绍

    由伯克利大学和谷歌合作的Squeezeformer[1]旨在推进下一代的语音识别主干网络,并达到了同等参数量下性能优于Conformer的结果,文章已经被NeurIPS 2022收录,本文尝试对其进行复现,实现中我们参考了Code[2]。

    Squeezeformer相对于Conformer,主要包含4个改进点:

    • Temporal U-Net 结构: 作者通过实验发现ASR训练过程中,中间层的帧embedding相似度很高,因此提出在时间维度上对中间层的帧数进行压缩,在最后一层恢复的方式。

    • MFCF block 结构: 作者推荐采用self-attention + ffn + conv module + ffn (MFCF)的组合替代标准Conformer中ffn + self-attention + conv module + ffn(这里的1/2也被取消)。

    • 微观架构改动: GLU被替换为Swish;同时作者推荐adaptive scale + PostLN的方式,代替单纯的PreLN或PostLN;subsampling中部分conv被替换为depthwise conv。

    • Scale up: 由于U-Net的结构,相同参数量的squeezeformer的FLOPs比Conformer更低,因此作者采用了scale up的方式给出了FLOPs与Conformer相同时的对比效果。

    算法实现

    (1) 下采样部分其中一个pointwise卷积被替换为depthwise卷积。

    1. class DepthwiseConv2dSubsampling4(BaseSubsampling):
    2.     """Depthwise Convolutional 2D subsampling (to 1/4 length).
    3.         Args:
    4.             idim (int): Input dimension.
    5.             odim (int): Output dimension.
    6.             pos_enc_class (nn.Module): position encoding class.
    7.             dw_stride (int): Whether do depthwise convolution.
    8.             input_size (int): filter bank dimension.
    9.         """
    10.     def __init__(
    11.             self, idim: int, odim: int,
    12.             pos_enc_class: torch.nn.Module,
    13.             dw_stride: bool = False,
    14.             input_size: int = 80,
    15.             input_dropout_rate: float = 0.1,
    16.             init_weights: bool = True
    17.     ):
    18.         super(DepthwiseConv2dSubsampling4self).__init__()
    19.         self.idim = idim
    20.         self.odim = odim
    21.         self.pw_conv = nn.Conv2d(
    22.             in_channels=idim, out_channels=odim, kernel_size=3, stride=2)
    23.         self.act1 = nn.ReLU()
    24.         self.dw_conv = nn.Conv2d(
    25.             in_channels=odim, out_channels=odim, kernel_size=3, stride=2,
    26.             groups=odim if dw_stride else 1
    27.         )
    28.         self.act2 = nn.ReLU()
    29.         self.pos_enc = pos_enc_class
    30.         self.input_proj = nn.Sequential(
    31.             nn.Linear(
    32.                 odim * (((input_size - 1// 2 - 1// 2), odim),
    33.             nn.Dropout(p=input_dropout_rate),
    34.         )
    35.         if init_weights:
    36.             linear_max = (odim * input_size / 4** -0.5
    37.             torch.nn.init.uniform_(
    38.                 self.input_proj.state_dict()['0.weight'], -linear_max, linear_max)
    39.             torch.nn.init.uniform_(
    40.                 self.input_proj.state_dict()['0.bias'], -linear_max, linear_max)
    41.         self.subsampling_rate = 4
    42.         # 6 = (3 - 1* 1 + (3 - 1* 2
    43.         self.right_context = 6

    (2) 作者通过对比相邻帧之间的Cosine Similarity发现,在Conformer模型中间层有着比较大的信息冗余,尤其是在序号更大的block。
    因此采用U-Net的结构替换原始的Conformer,核心部分是TimeReductionLayer。
    TimeReductionLayer 将单位帧对应时长由40ms变为80ms,即在时间维度上变为1/2。我们这里提供了1D和2D版本的TimeReductionLayer。

    1. class TimeReductionLayer1D(nn.Module):
    2.     def __init__(self, channel: int, out_dim: int,
    3.                  kernel_size: int = 5, stride: int = 2):
    4.         super(TimeReductionLayer1D, self).__init__()
    5.         self.channel = channel
    6.         self.out_dim = out_dim
    7.         self.kernel_size = kernel_size
    8.         self.stride = stride
    9.         self.padding = max(0self.kernel_size - self.stride)
    10.         self.dw_conv = nn.Conv1d(
    11.             in_channels=channel,
    12.             out_channels=channel,
    13.             kernel_size=kernel_size,
    14.             stride=stride,
    15.             padding=self.padding,
    16.             groups=channel,
    17.         )
    18.         self.pw_conv = nn.Conv1d(
    19.             in_channels=channel, out_channels=out_dim,
    20.             kernel_size=1, stride=1, padding=0, groups=1,
    21.         )
    22.         self.init_weights()
    23.     def init_weights(self):
    24.         dw_max = self.kernel_size ** -0.5
    25.         pw_max = self.channel ** -0.5
    26.         torch.nn.init.uniform_(self.dw_conv.weight, -dw_max, dw_max)
    27.         torch.nn.init.uniform_(self.dw_conv.bias, -dw_max, dw_max)
    28.         torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max)
    29.         torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max)
    30.     def forward(self, xs, xs_lens: torch.Tensor,
    31.                 mask: torch.Tensor = torch.ones((000), dtype=torch.bool),
    32.                 mask_pad: torch.Tensor = torch.ones((000), dtype=torch.bool),
    33.                 ):
    34.         xs = xs.transpose(12)  # [B, C, T]
    35.         xs = xs.masked_fill(mask_pad.eq(0), 0.0)
    36.         xs = self.dw_conv(xs)
    37.         xs = self.pw_conv(xs)
    38.         xs = xs.transpose(12)  # [B, T, C]
    39.         B, T, D = xs.size()
    40.         mask = mask[:, ::self.stride, ::self.stride]
    41.         mask_pad = mask_pad[:, :, ::self.stride]
    42.         L = mask_pad.size(-1)
    43.         # For JIT exporting, we remove F.pad operator.
    44.         if L - T < 0:
    45.             xs = xs[:, :L - T, :].contiguous()
    46.         else:
    47.             dummy_pad = torch.zeros(B, L - T, D, device=xs.device)
    48.             xs = torch.cat([xs, dummy_pad], dim=1)
    49.         xs_lens = torch.div(xs_lens + 12, rounding_mode='trunc')
    50.         return xs, xs_lens, mask, mask_pad

    (3) Recover部分实现,在时间维度上进行复制、映射和recover_tensor叠加。

    1.     # recover output length for ctc decode
    2.     xs = torch.repeat_interleave(xs, repeats=2, dim=1)
    3.     xs = self.time_recover_layer(xs)
    4.     recoverd_t = recover_tensor.size(1)
    5.     xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()

    (4) 替换FMCF为MFCF结构,这里默认采用PostLN,也兼容了PreLN。

    1.     def forward(
    2.             self,
    3.             x: torch.Tensor,
    4.             mask: torch.Tensor,
    5.             pos_emb: torch.Tensor,
    6.             mask_pad: torch.Tensor = torch.ones((000), dtype=torch.bool),
    7.             att_cache: torch.Tensor = torch.zeros((0000)),
    8.             cnn_cache: torch.Tensor = torch.zeros((0000)),
    9.         ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    10.         # self attention module
    11.         residual = x
    12.         if self.normalize_before:
    13.             x = self.layer_norm1(x)
    14.         x_att, new_att_cache = self.self_attn(
    15.             x, x, x, mask, pos_emb, att_cache)
    16.         if self.concat_after:
    17.             x_concat = torch.cat((x, x_att), dim=-1)
    18.             x = residual + self.concat_linear(x_concat)
    19.         else:
    20.             x = residual + self.dropout(x_att)
    21.         if not self.normalize_before:
    22.             x = self.layer_norm1(x)
    23.         # ffn module
    24.         residual = x
    25.         if self.normalize_before:
    26.             x = self.layer_norm2(x)
    27.         x = self.ffn1(x)
    28.         x = residual + self.dropout(x)
    29.         if not self.normalize_before:
    30.             x = self.layer_norm2(x)
    31.         # conv module
    32.         new_cnn_cache = torch.zeros((000), dtype=x.dtype, device=x.device)
    33.         residual = x
    34.         if self.normalize_before:
    35.             x = self.layer_norm3(x)
    36.         x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
    37.         x = residual + self.dropout(x)
    38.         if not self.normalize_before:
    39.             x = self.layer_norm3(x)
    40.         # ffn module
    41.         residual = x
    42.         if self.normalize_before:
    43.             x = self.layer_norm4(x)
    44.         x = self.ffn2(x)
    45.         x = residual + self.dropout(x)
    46.         if not self.normalize_before:
    47.             x = self.layer_norm4(x)
    48.         return x, mask, new_att_cache, new_cnn_cache

    (5) 如图所示,PreLN的部分被替换为adaptive scale。这里以FeedForward layer为例,adaptive scale指的是在输入layer之前加入一组可学习参数的仿射变换,以及合理的init weights方式。许多研究工作表明,PreLN的结构更容易稳定收敛,PostLN训练的模型效果更好,因此adaptive scale+PostLN的组合被视为两者优点的结合。

    1. class PositionwiseFeedForward(torch.nn.Module):
    2.     """Positionwise feed forward layer.
    3.     FeedForward are appied on each position of the sequence.
    4.     The output dim is same with the input dim.
    5.     Args:
    6.         idim (int): Input dimenstion.
    7.         hidden_units (int): The number of hidden units.
    8.         dropout_rate (float): Dropout rate.
    9.         activation (torch.nn.Module): Activation function
    10.     """
    11.     def __init__(self,
    12.                  idim: int,
    13.                  hidden_units: int,
    14.                  dropout_rate: float,
    15.                  activation: torch.nn.Module = torch.nn.ReLU(),
    16.                  adaptive_scale: bool = False,
    17.                  init_weights: bool = False
    18.                  ):
    19.         """Construct a PositionwiseFeedForward object."""
    20.         super(PositionwiseFeedForward, self).__init__()
    21.         self.idim = idim
    22.         self.hidden_units = hidden_units
    23.         self.w_1 = torch.nn.Linear(idim, hidden_units)
    24.         self.activation = activation
    25.         self.dropout = torch.nn.Dropout(dropout_rate)
    26.         self.w_2 = torch.nn.Linear(hidden_units, idim)
    27.         self.ada_scale = None
    28.         self.ada_bias = None
    29.         self.adaptive_scale = adaptive_scale
    30.         self.ada_scale = torch.nn.Parameter(
    31.             torch.ones([11, idim]), requires_grad=adaptive_scale)
    32.         self.ada_bias = torch.nn.Parameter(
    33.             torch.zeros([11, idim]), requires_grad=adaptive_scale)
    34.         if init_weights:
    35.             self.init_weights()
    36.     def init_weights(self):
    37.         ffn1_max = self.idim ** -0.5
    38.         ffn2_max = self.hidden_units ** -0.5
    39.         torch.nn.init.uniform_(self.w_1.weight.data, -ffn1_max, ffn1_max)
    40.         torch.nn.init.uniform_(self.w_1.bias.data, -ffn1_max, ffn1_max)
    41.         torch.nn.init.uniform_(self.w_2.weight.data, -ffn2_max, ffn2_max)
    42.         torch.nn.init.uniform_(self.w_2.bias.data, -ffn2_max, ffn2_max)
    43.     def forward(self, xs: torch.Tensor) -> torch.Tensor:
    44.         """Forward function.
    45.         Args:
    46.             xs: input tensor (B, L, D)
    47.         Returns:
    48.             output tensor, (B, L, D)
    49.         """
    50.         if self.adaptive_scale:
    51.             xs = self.ada_scale * xs + self.ada_bias
    52.         return self.w_2(self.dropout(self.activation(self.w_1(xs))))

    流式推理

    如下图所示,由于Squeezeformer在squeeze的部分与Conformer不同,为了在流式推理过程中保持接口使用方式不变,我们在Squeezeformer推理中额外采用了slice + pad的形式。缓存attention cache时,在时间维度上复制下采样倍数;计算下一个chunk时,按照下采样系数取出。

    attention cache & cnn cache 的核心代码如下:

    1.     factor = self.calculate_downsampling_factor(i)
    2.     xs, _, new_att_cache, new_cnn_cache = layer(
    3.         xs, att_mask, pos_emb,
    4.         att_cache=att_cache[i:i + 1][:, :, ::factor, :]
    5.         [:, :, :pos_emb.size(1) - xs.size(1), :] if
    6.         elayers > 0 else att_cache[:, :, ::factor, :],
    7.         cnn_cache=cnn_cache[i] if cnn_cache.size(0> 0 else cnn_cache
    8.     )
    9.     cached_att \
    10.         = new_att_cache[:, :, next_cache_start // factor:, :]
    11.     cached_cnn = new_cnn_cache.unsqueeze(0)
    12.     cached_att = cached_att.repeat_interleave(repeats=factor, dim=2)
    13.     if i == 0:
    14.         # record length for the first block as max length
    15.         max_att_len = cached_att.size(2)
    16.     r_att_cache.append(cached_att[:, :, :max_att_len, :])
    17.     r_cnn_cache.append(cached_cnn)

    另外,流式推理过程中,由于time reduce的padding导致边界处理稍有差异,使得调用forward接口效果与调用forward chunk接口稍有偏差,我们这里额外给出了一种stream reduce保持推理时的一致性。

    差异1:在forward接口中卷积会对全长进行pad,卷积计算到中间位置的可见帧为数据,而forward chunk接口会在当前chunk做pad

    1. self.padding = max(0self.kernel_size - self.stride)
    2. self.dw_conv = nn.Conv1d(
    3.     in_channels=channel,
    4.     out_channels=channel,
    5.     kernel_size=kernel_size,
    6.     stride=stride,
    7.     padding=self.padding,
    8.     groups=channel,
    9. )

    差异2:与上面类似,在L-T不为0时,forward与forward chunk会带来差异

    1. = mask_pad.size(-1)
    2. if L - T < 0:
    3.     xs = xs[:, :L - T, :].contiguous()
    4. else:
    5.     dummy_pad = torch.zeros(B, L - T, D, device=xs.device)
    6.     xs = torch.cat([xs, dummy_pad], dim=1)

    实验结果

    我们在WeNet上贡献了完整的Squeezeformer训练方案,并给出了在计算速度相对可比的情况下,不同大小模型的实验效果。

    1.在最普遍使用的Medium模型,我们给出了3种尺度的结果,
    分别是V0: 接近Conformer效果的最小模型,V1: 参数量相近的模型,以及V2: FLOPs相近的模型。

    2.Large模型我们给出了参数量相近情况下的对比结果。

    3.同时Squeezeformer也支持流式的训练和推理,在参数量相近的情况下,对比效果如下。

    Squeezeformer在Librispeech上的完整训练效果,详见LibriSpeech 实验结果[3]

    补充说明:

    • SM12-V1和U2++的参数量一致,效果差异主要来自squeeze layer的实现方式、BN同步、decoder以及流式训练方式。

    • 由于这个系列算法的在CNN结构的Norm方式采用了BN,我们通过实验发现syncbn可以带来提升,因此部分实验结构采用了syncbn的操作,这个部分后续也会在WeNet中更新。


    参考资料

    [1]Squeezeformer: https://arxiv.org/pdf/2206.00888v1.pdf

    [2]Code: https://github.com/kssteven418/Squeezeformer

    [3]README.md: https://github.com/wenet-e2e/wenet/tree/main/examples/librispeech/s0

    机器翻译,仅供参考

  • 相关阅读:
    Kubernetes客户端认证(一)—— 基于CA证书的双向认证方式
    《Effective Objective-C 2.0》读书笔记——熟悉Objective-C
    【Python】利用tkinter与图灵机器人制作智能聊天系统
    ZK和redis中是否会发生脑裂问题?
    请回答数据结构【布隆过滤器&位图】
    软件工程专业毕设题目选题推荐
    uni-app实现获取未来七天时间和星期几功能
    Compose原理-compose中是如何实现事件分法的
    基于PHP+MySQL音乐网站的设计与实现
    PLC网关用途、解决问题以及如何实现高效、稳定通信分享
  • 原文地址:https://blog.csdn.net/weixin_48827824/article/details/127887589