RE2 Text Matching in Practice


    Introduction

    In this article we implement RE2 for text matching. The implementation follows the official code at https://github.com/alibaba-edu/simple-effective-text-matching-pytorch.

    Model Implementation

    [Figure: RE2 model architecture]

    The RE2 architecture is shown in the figure above. It takes two text segments as input, and all component parameters are shared between the two sides except for the prediction layer and the alignment layer. The dashed box in the figure marks one block; N such blocks are stacked, the two segments interact inside each block through the alignment layer, and consecutive blocks are connected by augmented residual connections.

    Below we implement the model from the bottom up, following the official implementation.
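
    The snippets below assume the following imports:

    import math

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch import Tensor
    from torch.nn.utils import weight_norm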

    Embedding

    The embedding layer is simple: no character embeddings, just plain word embeddings.

    class Embedding(nn.Module):
        def __init__(self, vocab_size: int, embedding_dim: int, dropout: float) -> None:
            super().__init__()
    
            self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
    
            self.dropout = nn.Dropout(dropout)
    
        def forward(self, x: Tensor) -> Tensor:
            """
            Args:
                x (Tensor): (batch_size, seq_len)
    
            Returns:
                Tensor: (batch_size, seq_len, embedding_dim)
            """
            return self.dropout(self.embedding(x))
    

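    A quick shape check of the embedding layer (a minimal sketch with made-up sizes):

    embedding = Embedding(vocab_size=1000, embedding_dim=300, dropout=0.2)
    x = torch.randint(0, 1000, (2, 5))  # (batch_size, seq_len)
    print(embedding(x).shape)  # torch.Size([2, 5, 300])
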
    Encoder

    GeLU

    First we implement GeLU, a variant of ReLU that was later adopted in BERT. Its function curve looks like this:

    [Figure: the GELU activation curve]
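
    The code below implements the tanh approximation of GELU:
    $$\text{GELU}(x) \approx 0.5x\left(1 + \tanh\left(\sqrt{2/\pi}\,(x + 0.044715x^3)\right)\right)$$
    where $\sqrt{2/\pi} \approx 0.7978845608$ and $x + 0.044715x^3$ is factored as $x(1 + 0.044715x^2)$.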

    class GeLU(nn.Module):
        def forward(self, x: Tensor) -> Tensor:
            return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
    

    Linear

    We rewrite the linear layer; when activations is True, a GeLU activation is appended.

    class Linear(nn.Module):
        def __init__(
            self, in_features: int, out_features: int, activations: bool = True
        ) -> None:
            super().__init__()
    
            linear = nn.Linear(in_features, out_features)
            modules = [weight_norm(linear)]
            if activations:
                modules.append(GeLU())
    
            self.model = nn.Sequential(*modules)
            self.reset_parameters(activations)
    
        def reset_parameters(self, activations: bool) -> None:
            linear = self.model[0]
            nn.init.normal_(
                linear.weight,
                std=math.sqrt((2.0 if activations else 1.0) / linear.in_features),
            )
            nn.init.zeros_(linear.bias)
    
        def forward(self, x):
            return self.model(x)
    

    nn.Conv1d

    In the implementation of the compare-aggregate model we already covered torch.nn.Conv2d and the basic concepts of CNNs in detail.

    Here we use torch.nn.Conv1d to implement the multi-layer convolutional network from the paper, so this subsection takes a closer look at Conv1d.

    torch.nn.Conv1d
        in_channels: number of input channels; for text, the embedding dimension
        out_channels: number of output channels; each filter produces one channel
        kernel_size: size of the convolution kernel
        stride: stride of the convolution, default 1
        padding: amount of padding, default 0
        bias (bool): whether to add a bias, default True
    

    Let's walk through the computation with an example. Suppose the input is "W B G 是 冠 军" ("WBG is the champion"), with randomly generated embeddings:

    (Hoping WBG can beat T1 in the S13 final this afternoon.)

    import numpy as np
    import torch.nn as nn
    import torch
    
    batch_size = 1
    seq_len = 6
    embed_size = 3
    
    input_tensor = torch.rand(batch_size, seq_len, embed_size)
    print(input_tensor)
    print(input_tensor.shape)
    
    tensor([[[0.9291, 0.8333, 0.5160],
             [0.0543, 0.8149, 0.5704],
             [0.7831, 0.2263, 0.9279],
             [0.0898, 0.0758, 0.4401],
             [0.4321, 0.2098, 0.6666],
             [0.6183, 0.0609, 0.2330]]])
    torch.Size([1, 6, 3])
    

    Each character now has a 3-dimensional embedding vector:

    W — [0.9291, 0.8333, 0.5160]
    B — [0.0543, 0.8149, 0.5704]
    G — [0.7831, 0.2263, 0.9279]
    是 — [0.0898, 0.0758, 0.4401]
    冠 — [0.4321, 0.2098, 0.6666]
    军 — [0.6183, 0.0609, 0.2330]
    

    However, Conv1d expects in_channels, i.e. the embedding dimension, to come right after batch_size, so the shape goes from [1, 6, 3] to [1, 3, 6]:

    input_tensor = input_tensor.permute(0, 2, 1)
    # (batch_size, embed_size, seq_len)
    

    Illustrated:

    [Figure: the tensor after permute, with embed_size as the channel dimension]

    (The article wasn't even published yet, and they got swept 3:0.)

    Then we define a 1D convolution:

    input_channels = embed_size  # must equal embed_size
    output_channels = 2
    kernel_size = 2

    conv1d = nn.Conv1d(in_channels=input_channels, out_channels=output_channels, kernel_size=kernel_size)
    

    We can print the filter weight matrix:

    print(conv1d.weight)
    print(conv1d.weight.shape)
    
    Parameter containing:
    tensor([[[ 0.0025,  0.3353],
             [ 0.0620, -0.3916],
             [-0.3458, -0.0610]],
    
            [[-0.1731, -0.0787],
             [-0.0419, -0.2555],
             [-0.1429,  0.1656]]], requires_grad=True)
    torch.Size([2, 3, 2])
    

    The filter weights have shape (2, 3, 2): shape[0]=2 is the number of filters; shape[1]=3 is the input embedding size; shape[2]=2 is the filter size.

    A bias is added by default, one per filter (printed via conv1d.bias):

    Parameter containing:
    tensor([ 0.3760, -0.2881], requires_grad=True)
    torch.Size([2])
    

    Since we have two filters, there are two biases. Because kernel_size=2 and stride=1, a filter covers two character embeddings at a time, as shown below, and moves one position to the right at each step:

    [Figure: the kernel window covering the first two token embeddings]

    At this position, the first filter's convolution is computed as:

    sum([[0.9291, 0.0543],         [[ 0.0025,  0.3353],
         [0.8333, 0.8149],     *    [ 0.0620, -0.3916],      +    0.3760 (bias)
         [0.5160, 0.5704]]          [-0.3458, -0.0610]])
    

    The first filter's weights are multiplied element-wise with these two embeddings, the products are summed into a scalar, and the first filter's bias is added.

    In code:

    # compute the first convolution step manually
    # element-wise product of the first two embeddings and the kernel weights
    result = input_tensor[:, :, :2] * conv1d.weight
    print(result)
    # sum the products and add each filter's bias
    print(torch.sum(result[0]) + conv1d.bias[0])
    print(torch.sum(result[1]) + conv1d.bias[1])
    
    tensor([[[ 0.0024,  0.0182],
             [ 0.0517, -0.3191],
             [-0.1784, -0.0348]],

            [[-0.1608, -0.0043],
             [-0.0349, -0.2082],
             [-0.0737,  0.0944]]], grad_fn=<MulBackward0>)

    tensor(-0.0841, grad_fn=<AddBackward0>) # result of the first filter
    tensor(-0.6756, grad_fn=<AddBackward0>) # result of the second filter
    

    This is the result of the first convolution step; for the second step, the red box moves one position to the right and produces another value.

    [Figure: the kernel window after moving one step to the right]

    Finally the window reaches the last position of the input and the computation is done:

    [Figure: the kernel window at the last position]

    Five steps are needed in total, so each filter outputs 5 scalars; there are 2 filters and the batch size is 1.

    Doing the same with the built-in forward call:

    output = conv1d(input_tensor)
    print(output)
    print(output.shape)
    
    tensor([[[-0.0841,  0.3468,  0.0447,  0.2508,  0.3288],
             [-0.6756, -0.3790, -0.5193, -0.3470, -0.4926]]],
           grad_fn=<ConvolutionBackward0>)
    torch.Size([1, 2, 5])
    

    The output shape is [1, 2, 5], and its first column matches our manual computation.

    shape[0]=1 is the number of samples in the batch; shape[1]=2 is the number of filters, i.e. the desired number of output channels; shape[2]=5 is the sequence dimension after convolution.

    Ignoring dilation, the output length after convolution is determined by the kernel size, the stride, the padding, and the input seq_len:
    $$L_{\text{out}} = \left\lfloor \frac{\text{seq\_len} + 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1 \right\rfloor$$

    where $\lfloor \cdot \rfloor$ denotes rounding down.

    We can verify this with the parameters above:
    $$\frac{6 + 2 \times 0 - 2}{1} + 1 = 6 + 0 - 2 + 1 = 5$$
    The result, 5, matches the output.
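
    The same formula as a small helper for quick verification (a sketch, not part of the model):

    def conv1d_out_len(seq_len: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
        # floor((seq_len + 2 * padding - kernel_size) / stride) + 1
        return (seq_len + 2 * padding - kernel_size) // stride + 1

    assert conv1d_out_len(seq_len=6, kernel_size=2) == 5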

    Conv1d

    Now we implement RE2's multi-layer convolutional network. First comes a customized Conv1d that applies weight normalization via weight_norm and uses the GeLU activation.

    class Conv1d(nn.Module):
        def __init__(
            self, in_channels: int, out_channels: int, kernel_sizes: list[int]
        ) -> None:
            """
    
            Args:
                in_channels (int): the embedding_dim
                out_channels (int): number of filters
                kernel_sizes (list[int]): the size of kernel
            """
            super().__init__()
    
            out_channels = out_channels // len(kernel_sizes)
    
            convs = []
            # L_in is seq_len, L_out is output_dim of conv
            # L_out = (L_in + 2 * padding - kernel_size + 1)
            # and padding=(kernel_size - 1) // 2
            # L_out = (L_in + kernel_size - 1 - kernel_size + 1) = L_in
            for kernel_size in kernel_sizes:
                conv = nn.Conv1d(
                    in_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2
                )
                convs.append(nn.Sequential(weight_norm(conv), GeLU()))
            # output shape of each conv is (batch_size, out_channels(new), seq_len)
    
            self.model = nn.ModuleList(convs)
    
            self.reset_parameters()
    
        def reset_parameters(self) -> None:
            for seq in self.model:
                conv = seq[0]
                nn.init.normal_(
                    conv.weight,
                    std=math.sqrt(2.0 / (conv.in_channels * conv.kernel_size[0])),
                )
                nn.init.zeros_(conv.bias)
    
        def forward(self, x: Tensor) -> Tensor:
            """
    
            Args:
                x (Tensor): shape (batch_size, embedding_dim, seq_len)
    
            Returns:
                Tensor: (batch_size, out_channels, seq_len)
            """
            # back to (batch_size, out_channels, seq_len)
            return torch.cat([encoder(x) for encoder in self.model], dim=1)
    

    out_channels // len(kernel_sizes) splits the output size across the kernel sizes; at the end, torch.cat concatenates the pieces back to out_channels.

    padding=(kernel_size - 1) // 2 keeps the output length equal to the input seq_len; this requires kernel_size to be odd, since padding only accepts integers.
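
    For example, with kernel_size=3 and padding=1 the sequence length is preserved (made-up sizes):

    conv = nn.Conv1d(in_channels=300, out_channels=150, kernel_size=3, padding=1)
    x = torch.rand(1, 300, 50)  # (batch_size, embed_size, seq_len)
    print(conv(x).shape)  # torch.Size([1, 150, 50]), seq_len unchanged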

    weight_norm decomposes a weight into a magnitude and a direction, which can speed up training and improve generalization. The original weight's direction is kept, while the magnitude is learned by the weight normalization itself:
    $$\pmb w = g \frac{\pmb v}{\|\pmb v\|}$$
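
    With PyTorch's weight_norm, the magnitude g and the direction v become separate learnable parameters (a minimal illustration):

    layer = weight_norm(nn.Linear(4, 8))
    print(layer.weight_g.shape)  # torch.Size([8, 1]), one magnitude per output unit
    print(layer.weight_v.shape)  # torch.Size([8, 4]), the direction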

    Encoder Implementation

    class Encoder(nn.Module):
        def __init__(
            self,
            input_size: int,
            hidden_size: int,
            kernel_sizes: list[int],
            encoder_layers: int,
            dropout: float,
        ) -> None:
            """_summary_
    
            Args:
                input_size (int): embedding_dim or embedding_dim + hidden_size
                hidden_size (int): hidden size
                kernel_sizes (list[int]): the size of kernels
                encoder_layers (int): number of conv layers
                dropout (float): dropout ratio
            """
            super().__init__()
    
            self.encoders = nn.ModuleList(
                [
                    Conv1d(
                        in_channels=input_size if i == 0 else hidden_size,
                        out_channels=hidden_size,
                        kernel_sizes=kernel_sizes,
                    )
                    for i in range(encoder_layers)
                ]
            )
    
            self.dropout = nn.Dropout(dropout)
    
        def forward(self, x: Tensor, mask: Tensor) -> Tensor:
            """forward in encoder
    
            Args:
                x (Tensor): (batch_size, seq_len, input_size)
                mask (Tensor): (batch_size, seq_len, 1)
    
            Returns:
                Tensor: (batch_size, seq_len, hidden_size)
            """
            # x (batch_size, input_size, seq_len)
            x = x.transpose(1, 2)
            # mask (batch_size, 1, seq_len)
            mask = mask.transpose(1, 2)
    
            for i, encoder in enumerate(self.encoders):
                # fills elements of x with 0.0 where mask is False
                x.masked_fill_(~mask, 0.0)
                # using dropout
                if i > 0:
                    x = self.dropout(x)
                # returned x (batch_size, hidden_size, seq_len)
                x = encoder(x)
    
            # apply dropout
            x = self.dropout(x)
            # (batch_size, seq_len, hidden_size)
            return x.transpose(1, 2)
    

    We stack multiple Conv1d layers as the encoder. Note the difference between layer 0 and the rest: layer 0's input dimension is input_size (embedding_dim, or embedding_dim + hidden_size in later blocks); after layer 0's Conv1d the dimension becomes hidden_size, so all subsequent layers use in_channels=hidden_size.

    x.masked_fill_(~mask, 0.0) zeroes out the padded positions indicated by the mask.

    An RNN is not used as the encoder; the authors found RNNs slower without any performance gain.
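
    A shape check with assumed hyperparameters:

    encoder = Encoder(
        input_size=300, hidden_size=150, kernel_sizes=[3], encoder_layers=2, dropout=0.2
    )
    x = torch.rand(2, 10, 300)
    mask = torch.ones(2, 10, 1, dtype=torch.bool)
    print(encoder(x, mask).shape)  # torch.Size([2, 10, 150])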

    Alignment

    Next we implement the alignment layer. Alignment lets the two sequences interact, here through attention.
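
    Following the paper, the alignment scores are dot products of projected representations, and each position's aligned feature is the attention-weighted sum over the other sequence:
    $$e_{ij} = F(a_i)^\top F(b_j), \quad a'_i = \sum_{j=1}^{l_b} \frac{\exp(e_{ij})}{\sum_{k=1}^{l_b}\exp(e_{ik})} b_j, \quad b'_j = \sum_{i=1}^{l_a} \frac{\exp(e_{ij})}{\sum_{k=1}^{l_a}\exp(e_{kj})} a_i$$
    where $F$ is either an identity function or a single-layer feed-forward network, controlled by the project_func argument.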

    class Alignment(nn.Module):
        def __init__(
            self, input_size: int, hidden_size: int, dropout: float, project_func: str
        ) -> None:
            """
    
            Args:
                input_size (int): embedding_dim  + hidden_size  or embedding_dim  + hidden_size * 2
                hidden_size (int): hidden size
                dropout (float): dropout ratio
                project_func (str): identity or linear
            """
            super().__init__()
    
            self.temperature = nn.Parameter(torch.tensor(1 / math.sqrt(hidden_size)))
    
            if project_func != "identity":
                self.projection = nn.Sequential(
                    nn.Dropout(dropout), Linear(input_size, hidden_size)
                )
            else:
                self.projection = nn.Identity()
    
        def forward(
            self, a: Tensor, b: Tensor, mask_a: Tensor, mask_b: Tensor
        ) -> tuple[Tensor, Tensor]:
            """

            Args:
                a (Tensor): (batch_size, seq_len, input_size)
                b (Tensor): (batch_size, seq_len, input_size)
                mask_a (Tensor): (batch_size, seq_len, 1)
                mask_b (Tensor): (batch_size, seq_len, 1)

            Returns:
                tuple[Tensor, Tensor]: the aligned representations of a and b
            """
            # if projection == 'linear' : self.projection(*) -> (batch_size, seq_len,  hidden_size) -> transpose(*) -> (batch_size, hidden_size,  seq_len)
            # if projection == 'identity' : self.projection(*) -> (batch_size, seq_len, input_size) -> transpose(*) -> (batch_size, input_size,  seq_len)
            # attn (batch_size, seq_len_a,  seq_len_b)
            attn = (
                torch.matmul(self.projection(a), self.projection(b).transpose(1, 2))
                * self.temperature
            )
            # mask (batch_size, seq_len_a, seq_len_b)
            mask = torch.matmul(mask_a.float(), mask_b.transpose(1, 2).float())
            mask = mask.bool()
            # fill masked positions with a large negative value so they contribute ~0 after softmax
            attn.masked_fill_(~mask, -1e7)
            # attn_a (batch_size, seq_len_a,  seq_len_b)
            attn_a = F.softmax(attn, dim=1)
            # attn_b (batch_size, seq_len_a,  seq_len_b)
            attn_b = F.softmax(attn, dim=2)
            # feature_b  (batch_size, seq_len_b,  seq_len_a) x (batch_size, seq_len_a, input_size)
            # -> (batch_size, seq_len_b,  input_size)
            feature_b = torch.matmul(attn_a.transpose(1, 2), a)
            # feature_a  (batch_size, seq_len_a,  seq_len_b) x (batch_size, seq_len_b, input_size)
            # -> (batch_size, seq_len_a,  input_size)
            feature_a = torch.matmul(attn_b, b)
    
            return feature_a, feature_b
    
    

    Augmented Residual Connection

    [Figure: augmented residual connections across blocks (from the RE2 paper)]

    class AugmentedResidualConnection(nn.Module):
        def __init__(self) -> None:
            super().__init__()
    
        def forward(self, x: Tensor, res: Tensor, i: int) -> Tensor:
            """
    
            Args:
                x (Tensor): the output of pre block (batch_size, seq_len, hidden_size)
                res (Tensor): (batch_size, seq_len, embedding_size) or (batch_size, seq_len, embedding_size + hidden_size)
                    res[:,:,hidden_size:] is the output of Embedding layer
                    res[:,:,:hidden_size] is the output of previous two block
                i (int): layer index
    
            Returns:
                Tensor: (batch_size, seq_len,  hidden_size  + embedding_size)
            """
            if i == 1:
                # (batch_size, seq_len,  hidden_size  + embedding_size)
                return torch.cat([x, res], dim=-1)
            hidden_size = x.size(-1)
            # (res[:, :, :hidden_size] + x) is the summation of the output of previous two blocks
            # x (batch_size, seq_len, hidden_size)
            x = (res[:, :, :hidden_size] + x) * math.sqrt(0.5)
            # (batch_size, seq_len,  hidden_size  + embedding_size)
            return torch.cat([x, res[:, :, hidden_size:]], dim=-1)
    

    To provide richer features to the alignment process, RE2 uses an augmented version of residual connections between blocks.

    For a sequence of length $l$, denote the input and output of the $n$-th block as $x^{(n)} = (x^{(n)}_1, x^{(n)}_2, \cdots, x^{(n)}_l)$ and $o^{(n)} = (o^{(n)}_1, o^{(n)}_2, \cdots, o^{(n)}_l)$, and let $o^{(0)}$ be a sequence of zero vectors.

    The input of the first block, $x^{(1)}$, is the output of the embedding layer (the hollow rectangles in the figure). The input of the $n$-th block ($n \geq 2$), $x^{(n)}$, is the concatenation of the first block's input $x^{(1)}$ and the summation of the outputs of the previous two blocks (the diagonally hatched rectangles in the figure):
    $$x^{(n)}_i = [x^{(1)}_i; o^{(n-1)}_i + o^{(n-2)}_i]$$
    To put the formula in words: the input of block $n$ is the concatenation of two vectors, the first block's input and the element-wise sum of the outputs of the two preceding blocks. This is the augmented residual connection.
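
    A small shape walk-through of the connection (made-up sizes, embedding_dim=300 and hidden_size=150):

    connection = AugmentedResidualConnection()
    emb = torch.rand(2, 5, 300)  # x^(1), the embedding output
    o1 = torch.rand(2, 5, 150)  # output of block 1
    x2 = connection(o1, emb, i=1)  # (2, 5, 450) = [o1; emb]
    o2 = torch.rand(2, 5, 150)  # output of block 2
    x3 = connection(o2, x2, i=2)  # (2, 5, 450) = [(o1 + o2) * sqrt(0.5); emb]
    print(x2.shape, x3.shape)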

    Fusion Layer

    class Fusion(nn.Module):
        def __init__(self, input_size: int, hidden_size: int, dropout: float) -> None:
            """
    
            Args:
                input_size (int): embedding_dim  + hidden_size  or embedding_dim  + hidden_size * 2
                hidden_size (int): hidden size
                dropout (float): dropout ratio
            """
            super().__init__()
    
            self.dropout = nn.Dropout(dropout)
            self.fusion1 = Linear(input_size * 2, hidden_size, activations=True)
            self.fusion2 = Linear(input_size * 2, hidden_size, activations=True)
            self.fusion3 = Linear(input_size * 2, hidden_size, activations=True)
            self.fusion = Linear(hidden_size * 3, hidden_size, activations=True)
    
        def forward(self, x: Tensor, align: Tensor) -> Tensor:
            """
    
            Args:
                x (Tensor): input (batch_size, seq_len, input_size)
                align (Tensor): output of Alignment (batch_size, seq_len,  input_size)
    
            Returns:
                Tensor: (batch_size, seq_len, hidden_size)
            """
            # x1 (batch_size, seq_len, hidden_size)
            x1 = self.fusion1(torch.cat([x, align], dim=-1))
            # x2 (batch_size, seq_len, hidden_size)
            x2 = self.fusion2(torch.cat([x, x - align], dim=-1))
            # x3 (batch_size, seq_len, hidden_size)
            x3 = self.fusion3(torch.cat([x, x * align], dim=-1))
            # x (batch_size, seq_len, hidden_size * 3)
            x = torch.cat([x1, x2, x3], dim=-1)
            x = self.dropout(x)
            # (batch_size, seq_len, hidden_size)
            return self.fusion(x)
    

    融合层通过三个方面比较了局部和对齐表示(分别为对齐层的输入和输出),然后对它们进行融合。

    For the first sequence, the output of the fusion layer, $\bar a$, is computed as:
    $$\begin{aligned} \bar a^1_i &= G_1([a_i; a'_i]), \\ \bar a^2_i &= G_2([a_i; a_i - a'_i]), \\ \bar a^3_i &= G_3([a_i; a_i \circ a'_i]), \\ \bar a_i &= G([\bar a^1_i; \bar a^2_i; \bar a^3_i]), \end{aligned}$$
    where $G_1, G_2, G_3$ and $G$ are single-layer feed-forward networks with independent parameters, and $\circ$ denotes element-wise multiplication.

    The subtraction operator ($-$) highlights the difference between the two vectors, while the multiplication highlights their similarity. The computation of $\bar b$ for the other sequence is analogous.

    These operations resemble ESIM's, with an added feed-forward network.
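
    A shape check with assumed sizes (input_size=450 as in the first block):

    fusion = Fusion(input_size=450, hidden_size=150, dropout=0.2)
    x = torch.rand(2, 10, 450)  # the encoder input concatenated with its output
    align = torch.rand(2, 10, 450)  # the output of the alignment layer
    print(fusion(x, align).shape)  # torch.Size([2, 10, 150])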

    Afterwards, a pooling layer produces fixed-length vectors.

    Pooling Layer

    class Pooling(nn.Module):
        def forward(self, x: Tensor, mask: Tensor) -> Tensor:
            """
    
            Args:
                x (Tensor): (batch_size, seq_len, hidden_size)
                mask (Tensor): (batch_size, seq_len, 1)
    
            Returns:
                Tensor: (batch_size, hidden_size)
            """
            # max returns a namedtuple (values, indices), we only need values
            return x.masked_fill(~mask, -float("inf")).max(dim=1)[0]
    

    The pooling layer takes the maximum over the time-step dimension, ignoring padded positions.
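
    A tiny example showing that padded positions are ignored:

    pooling = Pooling()
    x = torch.tensor([[[1.0, 4.0], [3.0, 2.0], [9.0, 9.0]]])  # (1, 3, 2)
    mask = torch.tensor([[[True], [True], [False]]])  # the last step is padding
    print(pooling(x, mask))  # tensor([[3., 4.]]), the 9s are masked out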

    Prediction Layer

    class Prediction(nn.Module):
        def __init__(self, hidden_size: int, num_classes: int, dropout: float) -> None:
            super().__init__()
            self.dense = nn.Sequential(
                nn.Dropout(dropout),
                Linear(hidden_size * 4, hidden_size, activations=True),
                nn.Dropout(dropout),
                Linear(hidden_size, num_classes),
            )
    
        def forward(self, a: Tensor, b: Tensor) -> Tensor:
            """
    
            Args:
                a (Tensor): (batch_size, hidden_size)
                b (Tensor): (batch_size, hidden_size)
    
            Returns:
                Tensor: (batch_size, num_classes)
            """
            return self.dense(torch.cat([a, b, a - b, a * b], dim=-1))
    
    

    The prediction layer is straightforward; it applies one more fusion to the input vectors:
    $$\hat{\pmb y} = H([v_1; v_2; v_1 - v_2; v_1 \circ v_2])$$

    RE2 Implementation

    The RE2 model is a stack of the layers above:

    class RE2(nn.Module):
        def __init__(self, args) -> None:
            super().__init__()
    
            self.embedding = Embedding(args.vocab_size, args.embedding_dim, args.dropout)
    
            self.connection = AugmentedResidualConnection()
    
            self.blocks = nn.ModuleList(
                [
                    nn.ModuleDict(
                        {
                            "encoder": Encoder(
                                args.embedding_dim
                                if i == 0
                                else args.embedding_dim + args.hidden_size,
                                args.hidden_size,
                                args.kernel_sizes,
                                args.encoder_layers,
                                args.dropout,
                            ),
                            "alignment": Alignment(
                                args.embedding_dim + args.hidden_size
                                if i == 0
                                else args.embedding_dim + args.hidden_size * 2,
                                args.hidden_size,
                                args.dropout,
                                args.project_func,
                            ),
                            "fusion": Fusion(
                                args.embedding_dim + args.hidden_size
                                if i == 0
                                else args.embedding_dim + args.hidden_size * 2,
                                args.hidden_size,
                                args.dropout,
                            ),
                        }
                    )
                    for i in range(args.num_blocks)
                ]
            )
    
            self.pooling = Pooling()
            self.prediction = Prediction(args.hidden_size, args.num_classes, args.dropout)
    
        def forward(self, a: Tensor, b: Tensor, mask_a: Tensor, mask_b: Tensor) -> Tensor:
            """
            Args:
                a (Tensor): (batch_size, seq_len)
                b (Tensor): (batch_size, seq_len)
                mask_a (Tensor): (batch_size, seq_len, 1)
                mask_b (Tensor): (batch_size, seq_len, 1)
    
            Returns:
                Tensor: (batch_size, num_classes)
            """
            # a (batch_size, seq_len, embedding_dim)
            a = self.embedding(a)
            # b (batch_size, seq_len, embedding_dim)
            b = self.embedding(b)
    
            res_a, res_b = a, b
    
            for i, block in enumerate(self.blocks):
                if i > 0:
                    # a (batch_size, seq_len, embedding_dim + hidden_size)
                    a = self.connection(a, res_a, i)
                    # b (batch_size, seq_len, embedding_dim + hidden_size)
                    b = self.connection(b, res_b, i)
                    # now embeddings saved to res_a[:,:,hidden_size:]
                    res_a, res_b = a, b
                # a_enc (batch_size, seq_len, hidden_size)
                a_enc = block["encoder"](a, mask_a)
                # b_enc (batch_size, seq_len, hidden_size)
                b_enc = block["encoder"](b, mask_b)
                # concating the input and output of encoder
                # a (batch_size, seq_len, embedding_dim + hidden_size or embedding_dim + hidden_size * 2)
                a = torch.cat([a, a_enc], dim=-1)
                # b (batch_size, seq_len, embedding_dim + hidden_size or embedding_dim + hidden_size * 2)
                b = torch.cat([b, b_enc], dim=-1)
                # align_a (batch_size, seq_len,  embedding_dim + hidden_size or embedding_dim + hidden_size * 2)
                # align_b (batch_size, seq_len,  embedding_dim + hidden_size or embedding_dim + hidden_size * 2)
                align_a, align_b = block["alignment"](a, b, mask_a, mask_b)
                # a (batch_size, seq_len,  hidden_size)
                a = block["fusion"](a, align_a)
                # b (batch_size, seq_len,  hidden_size)
                b = block["fusion"](b, align_b)
            # a (batch_size, hidden_size)
            a = self.pooling(a, mask_a)
            # b (batch_size, hidden_size)
            b = self.pooling(b, mask_b)
            # (batch_size, num_classes)
            return self.prediction(a, b)
    
    

    Note how the input dimensions differ between blocks.
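
    A quick end-to-end smoke test with assumed hyperparameters:

    from argparse import Namespace

    args = Namespace(
        vocab_size=1000, embedding_dim=300, hidden_size=150, kernel_sizes=[3],
        encoder_layers=2, num_blocks=2, dropout=0.2, project_func="linear", num_classes=2,
    )
    model = RE2(args)
    a = torch.randint(0, 1000, (2, 10))
    b = torch.randint(0, 1000, (2, 10))
    mask_a = torch.ones(2, 10, 1, dtype=torch.bool)
    mask_b = torch.ones(2, 10, 1, dtype=torch.bool)
    print(model(a, b, mask_a, mask_b).shape)  # torch.Size([2, 2])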

    Data Preparation

    A previous article explains this data preparation step in detail.

    from collections import defaultdict
    from tqdm import tqdm
    import numpy as np
    import json
    from torch.utils.data import Dataset
    import pandas as pd
    from typing import Tuple
    
    UNK_TOKEN = "<UNK>"
    PAD_TOKEN = "<PAD>"
    
    
    class Vocabulary:
        """Class to process text and extract vocabulary for mapping"""
    
        def __init__(self, token_to_idx: dict = None, tokens: list[str] = None) -> None:
            """
            Args:
                token_to_idx (dict, optional): a pre-existing map of tokens to indices. Defaults to None.
                tokens (list[str], optional): a list of unique tokens with no duplicates. Defaults to None.
            """
    
            assert any(
                [tokens, token_to_idx]
            ), "At least one of these parameters should be set as not None."
            if token_to_idx:
                self._token_to_idx = token_to_idx
            else:
                self._token_to_idx = {}
                if PAD_TOKEN not in tokens:
                    tokens = [PAD_TOKEN] + tokens
    
                for idx, token in enumerate(tokens):
                    self._token_to_idx[token] = idx
    
            self._idx_to_token = {idx: token for token, idx in self._token_to_idx.items()}
    
            self.unk_index = self._token_to_idx[UNK_TOKEN]
            self.pad_index = self._token_to_idx[PAD_TOKEN]
    
        @classmethod
        def build(
            cls,
            sentences: list[list[str]],
            min_freq: int = 2,
            reserved_tokens: list[str] = None,
        ) -> "Vocabulary":
            """Construct the Vocabulary from sentences
    
            Args:
                sentences (list[list[str]]): a list of tokenized sequences
                min_freq (int, optional): the minimum word frequency to be saved. Defaults to 2.
                reserved_tokens (list[str], optional): the reserved tokens to add into the Vocabulary. Defaults to None.
    
            Returns:
                Vocabulary: a Vocabulary instance
            """
    
            token_freqs = defaultdict(int)
            for sentence in tqdm(sentences):
                for token in sentence:
                    token_freqs[token] += 1
    
            unique_tokens = (reserved_tokens if reserved_tokens else []) + [UNK_TOKEN]
            unique_tokens += [
                token
                for token, freq in token_freqs.items()
                if freq >= min_freq and token != UNK_TOKEN
            ]
            return cls(tokens=unique_tokens)
    
        def __len__(self) -> int:
            return len(self._idx_to_token)
    
        def __getitem__(self, tokens: list[str] | str) -> list[int] | int:
            """Retrieve the indices associated with the tokens or the index with the single token
    
            Args:
                tokens (list[str] | str): a list of tokens or single token
    
            Returns:
                list[int] | int: the indices or the single index
            """
            if not isinstance(tokens, (list, tuple)):
                return self._token_to_idx.get(tokens, self.unk_index)
            return [self.__getitem__(token) for token in tokens]
    
        def lookup_token(self, indices: list[int] | int) -> list[str] | str:
            """Retrive the tokens associated with the indices or the token with the single index
    
            Args:
                indices (list[int] | int): a list of index or single index
    
            Returns:
                list[str] | str: the corresponding tokens (or token)
            """
    
            if not isinstance(indices, (list, tuple)):
                return self._idx_to_token[indices]
    
            return [self._idx_to_token[index] for index in indices]
    
        def to_serializable(self) -> dict:
            """Returns a dictionary that can be serialized"""
            return {"token_to_idx": self._token_to_idx}
    
        @classmethod
        def from_serializable(cls, contents: dict) -> "Vocabulary":
            """Instantiates the Vocabulary from a serialized dictionary
    
    
            Args:
                contents (dict): a dictionary generated by `to_serializable`
    
            Returns:
                Vocabulary: the Vocabulary instance
            """
            return cls(**contents)
    
        def __repr__(self):
            return f"{len(self)})>"
    
    
    class TMVectorizer:
        """The Vectorizer which vectorizes the Vocabulary"""
    
        def __init__(self, vocab: Vocabulary, max_len: int) -> None:
            """
            Args:
                vocab (Vocabulary): maps characters to integers
                max_len (int): the max length of the sequence in the dataset
            """
            self.vocab = vocab
            self.max_len = max_len
            self.padding_index = vocab.pad_index
    
        def _vectorize(self, indices: list[int], vector_length: int = -1) -> np.ndarray:
            """Vectorize the provided indices
    
            Args:
                indices (list[int]): a list of integers that represent a sequence
                vector_length (int, optional): an argument for forcing the length of the index vector. Defaults to -1.
    
            Returns:
                np.ndarray: the vectorized index array
            """
    
            if vector_length <= 0:
                vector_length = len(indices)
    
            vector = np.zeros(vector_length, dtype=np.int64)
            if len(indices) > vector_length:
                vector[:] = indices[:vector_length]
            else:
                vector[: len(indices)] = indices
                vector[len(indices) :] = self.padding_index
    
            return vector
    
        def _get_indices(self, sentence: list[str]) -> list[int]:
            """Return the vectorized sentence
    
            Args:
                sentence (list[str]): list of tokens
            Returns:
                indices (list[int]): list of integers representing the sentence
            """
            return [self.vocab[token] for token in sentence]
    
        def vectorize(
            self, sentence: list[str], use_dataset_max_length: bool = True
        ) -> np.ndarray:
            """
            Return the vectorized sequence
    
            Args:
                sentence (list[str]): raw sentence from the dataset
                use_dataset_max_length (bool): whether to use the global max vector length
            Returns:
                the vectorized sequence with padding
            """
            vector_length = -1
            if use_dataset_max_length:
                vector_length = self.max_len
    
            indices = self._get_indices(sentence)
            vector = self._vectorize(indices, vector_length=vector_length)
    
            return vector
    
        @classmethod
        def from_serializable(cls, contents: dict) -> "TMVectorizer":
            """Instantiates the TMVectorizer from a serialized dictionary
    
            Args:
                contents (dict): a dictionary generated by `to_serializable`
    
            Returns:
                TMVectorizer:
            """
            vocab = Vocabulary.from_serializable(contents["vocab"])
            max_len = contents["max_len"]
            return cls(vocab=vocab, max_len=max_len)
    
        def to_serializable(self) -> dict:
            """Returns a dictionary that can be serialized
    
            Returns:
                dict: a dict contains Vocabulary instance and max_len attribute
            """
            return {"vocab": self.vocab.to_serializable(), "max_len": self.max_len}
    
        def save_vectorizer(self, filepath: str) -> None:
            """Dump this TMVectorizer instance to file
    
            Args:
                filepath (str): the path to store the file
            """
            with open(filepath, "w") as f:
                json.dump(self.to_serializable(), f)
    
        @classmethod
        def load_vectorizer(cls, filepath: str) -> "TMVectorizer":
            """Load TMVectorizer from a file
    
            Args:
                filepath (str): the path stored the file
    
            Returns:
                TMVectorizer:
            """
            with open(filepath) as f:
                return TMVectorizer.from_serializable(json.load(f))
    
    
    class TMDataset(Dataset):
        """Dataset for text matching"""
    
        def __init__(self, text_df: pd.DataFrame, vectorizer: TMVectorizer) -> None:
            """
    
            Args:
                text_df (pd.DataFrame): a DataFrame which contains the processed data examples
                vectorizer (TMVectorizer): a TMVectorizer instance
            """
    
            self.text_df = text_df
            self._vectorizer = vectorizer
    
        def __getitem__(
            self, index: int
        ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
            row = self.text_df.iloc[index]
    
            vector1 = self._vectorizer.vectorize(row.sentence1)
            vector2 = self._vectorizer.vectorize(row.sentence2)
    
            mask1 = vector1 != self._vectorizer.padding_index
            mask2 = vector2 != self._vectorizer.padding_index
    
            return (vector1, vector2, mask1, mask2, row.label)
    
        def get_vectorizer(self) -> TMVectorizer:
            return self._vectorizer
    
        def __len__(self) -> int:
            return len(self.text_df)
    
    

    This is almost identical to previous articles; the only difference is the added mask marking padded positions.
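
    A small usage sketch (the exact token indices depend on insertion order):

    sentences = [["hello", "world"], ["hello", "there"]]
    vocab = Vocabulary.build(sentences, min_freq=1)
    vectorizer = TMVectorizer(vocab, max_len=5)
    vector = vectorizer.vectorize(["hello", "world"])
    print(vector)  # e.g. [2 3 0 0 0]
    print(vector != vectorizer.padding_index)  # [ True  True False False False]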

    Model Training

    learning_rate=1e-3,
    batch_size=256,
    num_epochs=10,
    max_len=50,
    embedding_dim=300,
    hidden_size=150,
    encoder_layers=2,
    num_blocks=2,
    kernel_sizes=[3],
    dropout=0.2,
    min_freq=2,
    project_func="linear",
    grad_clipping=2.0,
    print_every=300,
    num_classes=2,
    

    After a few experiments, the best-performing configuration is the one shown above, with a learning rate of 1e-3 and gradient clipping at 2.0.

    As in the paper, gradient clipping is added; the paper's exponential learning-rate decay is replaced by simply using AdamW.
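
    The metrics helper used by the functions below is not shown in this article; a minimal sketch computing accuracy, precision, recall, and F1 for the positive class might look like this (hypothetical implementation):

    def metrics(y: Tensor, y_pred: Tensor) -> Tuple[float, float, float, float]:
        TP = ((y_pred == 1) & (y == 1)).sum().item()
        TN = ((y_pred == 0) & (y == 0)).sum().item()
        FP = ((y_pred == 1) & (y == 0)).sum().item()
        FN = ((y_pred == 0) & (y == 1)).sum().item()
        acc = (TP + TN) / (TP + TN + FP + FN)
        p = TP / (TP + FP + 1e-10)
        r = TP / (TP + FN + 1e-10)
        f1 = 2 * p * r / (p + r + 1e-10)
        return acc, p, r, f1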

    The evaluation and training functions are:

    def evaluate(
        data_iter: DataLoader, model: nn.Module
    ) -> Tuple[float, float, float, float]:
        y_list, y_pred_list = [], []
        model.eval()
        for x1, x2, mask1, mask2, y in tqdm(data_iter):
            x1 = x1.to(device).long()
            x2 = x2.to(device).long()
            mask1 = mask1.to(device).bool().unsqueeze(2)
            mask2 = mask2.to(device).bool().unsqueeze(2)
            y = y.float().to(device)
    
            output = model(x1, x2, mask1, mask2)
    
            pred = torch.argmax(output, dim=1).long()
    
            y_pred_list.append(pred)
            y_list.append(y)
    
        y_pred = torch.cat(y_pred_list, 0)
        y = torch.cat(y_list, 0)
        acc, p, r, f1 = metrics(y, y_pred)
        return acc, p, r, f1
    
    
    def train(
        data_iter: DataLoader,
        model: nn.Module,
        criterion: nn.CrossEntropyLoss,
        optimizer: torch.optim.Optimizer,
        grad_clipping: float,
        print_every: int = 500,
        verbose=True,
    ) -> None:
        model.train()
    
        for step, (x1, x2, mask1, mask2, y) in enumerate(tqdm(data_iter)):
            x1 = x1.to(device).long()
            x2 = x2.to(device).long()
            mask1 = mask1.to(device).bool().unsqueeze(2)
            mask2 = mask2.to(device).bool().unsqueeze(2)
        y = y.long().to(device)
    
            output = model(x1, x2, mask1, mask2)
    
            loss = criterion(output, y)
    
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clipping)
    
            optimizer.step()
    
            if verbose and (step + 1) % print_every == 0:
                pred = torch.argmax(output, dim=1).long()
                acc, p, r, f1 = metrics(y, pred)
    
                print(
                    f" TRAIN iter={step+1} loss={loss.item():.6f} accuracy={acc:.3f} precision={p:.3f} recal={r:.3f} f1 score={f1:.4f}"
                )
    

    The core training code is:

        
    model = RE2(args)

    print(f"Model: {model}")

    model_saved_path = os.path.join(args.save_dir, args.model_state_file)
    if args.reload_model and os.path.exists(model_saved_path):
        model.load_state_dict(torch.load(model_saved_path))
        print("Reloaded model")
    else:
        print("New model")

    model = model.to(device)

    model_save_path = os.path.join(
        args.save_dir, f"{datetime.now().strftime('%Y%m%d%H%M%S')}-model.pth"
    )

    train_data_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True
    )
    dev_data_loader = DataLoader(dev_dataset, batch_size=args.batch_size)
    test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(args.num_epochs):
        train(
            train_data_loader,
            model,
            criterion,
            optimizer,
            args.grad_clipping,
            print_every=args.print_every,
            verbose=args.verbose,
        )
        print("Begin evaluate on dev set.")
        with torch.no_grad():
            acc, p, r, f1 = evaluate(dev_data_loader, model)

            print(
                f"EVALUATE [{epoch+1}/{args.num_epochs}]  accuracy={acc:.3f} precision={p:.3f} recall={r:.3f} f1 score={f1:.4f}"
            )

    model.eval()

    acc, p, r, f1 = evaluate(test_data_loader, model)
    print(f"TEST accuracy={acc:.3f} precision={p:.3f} recall={r:.3f} f1 score={f1:.4f}")
    
    Arguments : Namespace(dataset_csv='text_matching/data/lcqmc/{}.txt', vectorizer_file='vectorizer.json', model_state_file='model.pth', pandas_file='dataframe.{}.pkl', save_dir='D:\\workspace\\nlp-in-action\\text_matching\\re2\\model_storage', reload_model=False, cuda=True, learning_rate=0.001, batch_size=256, num_epochs=10, max_len=50, embedding_dim=300, hidden_size=150, encoder_layers=2, num_blocks=2, kernel_sizes=[3], dropout=0.2, min_freq=2, project_func='linear', grad_clipping=2.0, print_every=300, lr_decay_rate=0.95, num_classes=2, verbose=True)
    Using device: cuda:0.
    Loads cached dataframes.
    Loads vectorizer file.
    Model: RE2(
      (embedding): Embedding(
        (embedding): Embedding(35925, 300, padding_idx=0)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (connection): AugmentedResidualConnection()
      (blocks): ModuleList(
        (0): ModuleDict(
          (encoder): Encoder(
            (encoders): ModuleList(
              (0): Conv1d(
                (model): ModuleList(
                  (0): Sequential(
                    (0): Conv1d(300, 150, kernel_size=(3,), stride=(1,), padding=(1,))
                    (1): GeLU()
                  )
                )
              )
              (1): Conv1d(
                (model): ModuleList(
                  (0): Sequential(
                    (0): Conv1d(150, 150, kernel_size=(3,), stride=(1,), padding=(1,))
                    (1): GeLU()
                  )
                )
              )
            )
            (dropout): Dropout(p=0.2, inplace=False)
          )
          (alignment): Alignment(
            (projection): Sequential(
              (0): Dropout(p=0.2, inplace=False)
              (1): Linear(
                (model): Sequential(
                  (0): Linear(in_features=450, out_features=150, bias=True)
                  (1): GeLU()
                )
              )
            )
          )
          (fusion): Fusion(
            (dropout): Dropout(p=0.2, inplace=False)
            (fusion1): Linear(
              (model): Sequential(
                (0): Linear(in_features=900, out_features=150, bias=True)
                (1): GeLU()
              )
            )
            (fusion2): Linear(
              (model): Sequential(
                (0): Linear(in_features=900, out_features=150, bias=True)
                (1): GeLU()
              )
            )
            (fusion3): Linear(
              (model): Sequential(
                (0): Linear(in_features=900, out_features=150, bias=True)
                (1): GeLU()
              )
            )
            (fusion): Linear(
              (model): Sequential(
                (0): Linear(in_features=450, out_features=150, bias=True)
                (1): GeLU()
              )
            )
          )
        )
        (1): ModuleDict(
          (encoder): Encoder(
            (encoders): ModuleList(
              (0): Conv1d(
                (model): ModuleList(
                  (0): Sequential(
                    (0): Conv1d(450, 150, kernel_size=(3,), stride=(1,), padding=(1,))
                    (1): GeLU()
                  )
                )
              )
              (1): Conv1d(
                (model): ModuleList(
                  (0): Sequential(
                    (0): Conv1d(150, 150, kernel_size=(3,), stride=(1,), padding=(1,))
                    (1): GeLU()
                  )
                )
              )
            )
            (dropout): Dropout(p=0.2, inplace=False)
          )
          (alignment): Alignment(
            (projection): Sequential(
              (0): Dropout(p=0.2, inplace=False)
              (1): Linear(
                (model): Sequential(
                  (0): Linear(in_features=600, out_features=150, bias=True)
                  (1): GeLU()
                )
              )
            )
          )
          (fusion): Fusion(
            (dropout): Dropout(p=0.2, inplace=False)
            (fusion1): Linear(
              (model): Sequential(
                (0): Linear(in_features=1200, out_features=150, bias=True)
                (1): GeLU()
              )
            )
            (fusion2): Linear(
              (model): Sequential(
                (0): Linear(in_features=1200, out_features=150, bias=True)
                (1): GeLU()
              )
            )
            (fusion3): Linear(
              (model): Sequential(
                (0): Linear(in_features=1200, out_features=150, bias=True)
                (1): GeLU()
              )
            )
            (fusion): Linear(
              (model): Sequential(
                (0): Linear(in_features=450, out_features=150, bias=True)
                (1): GeLU()
              )
            )
          )
        )
      )
      (pooling): Pooling()
      (prediction): Prediction(
        (dense): Sequential(
          (0): Dropout(p=0.2, inplace=False)
          (1): Linear(
            (model): Sequential(
              (0): Linear(in_features=600, out_features=150, bias=True)
              (1): GeLU()
            )
          )
          (2): Dropout(p=0.2, inplace=False)
          (3): Linear(
            (model): Sequential(
              (0): Linear(in_features=150, out_features=2, bias=True)
              (1): GeLU()
            )
          )
        )
      )
    )
    New model
     32%|█████████████████████████████████████████████████████▌                                                                                                                 | 299/933 [01:16<02:38,  4.00it/s] 
    TRAIN iter=300 loss=0.273509 accuracy=0.887 precision=0.885 recal=0.926 f1 score=0.9049
     64%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                           | 599/933 [02:31<01:23,  3.99it/s] 
    TRAIN iter=600 loss=0.296151 accuracy=0.859 precision=0.897 recal=0.861 f1 score=0.8784
     96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉      | 899/933 [03:46<00:08,  4.00it/s] 
    TRAIN iter=900 loss=0.262893 accuracy=0.875 precision=0.887 recal=0.887 f1 score=0.8873
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 933/933 [03:54<00:00,  3.98it/s]
    Begin evalute on dev set.
    100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:02<00:00, 14.60it/s] 
    EVALUATE [1/10]  accuracy=0.752 precision=0.737 recal=0.783 f1 score=0.7592
     32%|█████████████████████████████████████████████████████▌                                                                                                                 | 299/933 [01:14<02:37,  4.03it/s] 
    TRAIN iter=300 loss=0.272779 accuracy=0.898 precision=0.919 recal=0.907 f1 score=0.9133
     64%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                           | 599/933 [02:29<01:23,  3.98it/s] 
    TRAIN iter=600 loss=0.238999 accuracy=0.898 precision=0.907 recal=0.930 f1 score=0.9187
     96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉      | 899/933 [03:44<00:08,  4.00it/s] 
    TRAIN iter=900 loss=0.225822 accuracy=0.910 precision=0.929 recal=0.909 f1 score=0.9187
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 933/933 [03:52<00:00,  4.01it/s]
    Begin evalute on dev set.
    100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:02<00:00, 14.59it/s] 
    EVALUATE [2/10]  accuracy=0.787 precision=0.763 recal=0.831 f1 score=0.7956
     32%|█████████████████████████████████████████████████████▌                                                                                                                 | 299/933 [01:14<02:37,  4.03it/s] 
    TRAIN iter=300 loss=0.260889 accuracy=0.902 precision=0.929 recal=0.912 f1 score=0.9206
     64%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                           | 599/933 [02:29<01:22,  4.03it/s] 
    TRAIN iter=600 loss=0.216830 accuracy=0.910 precision=0.929 recal=0.923 f1 score=0.9256
     96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉      | 899/933 [03:43<00:08,  4.06it/s] 
    TRAIN iter=900 loss=0.162659 accuracy=0.945 precision=0.944 recal=0.958 f1 score=0.9510
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 933/933 [03:51<00:00,  4.02it/s]
    Begin evalute on dev set.
    100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:02<00:00, 14.73it/s] 
    EVALUATE [3/10]  accuracy=0.816 precision=0.809 recal=0.827 f1 score=0.8179
     32%|█████████████████████████████████████████████████████▌                                                                                                                 | 299/933 [01:14<02:36,  4.06it/s] 
    TRAIN iter=300 loss=0.228807 accuracy=0.906 precision=0.909 recal=0.922 f1 score=0.9155
     64%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                           | 599/933 [02:28<01:22,  4.05it/s] 
    TRAIN iter=600 loss=0.186292 accuracy=0.926 precision=0.932 recal=0.938 f1 score=0.9347
     96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉      | 899/933 [03:42<00:08,  4.06it/s] 
    TRAIN iter=900 loss=0.160805 accuracy=0.953 precision=0.957 recal=0.957 f1 score=0.9568
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 933/933 [03:50<00:00,  4.04it/s]
    Begin evalute on dev set.
    100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:02<00:00, 14.73it/s] 
    EVALUATE [4/10]  accuracy=0.814 precision=0.804 recal=0.832 f1 score=0.8176
     32%|█████████████████████████████████████████████████████▌                                                                                                                 | 299/933 [01:13<02:36,  4.06it/s] 
    TRAIN iter=300 loss=0.190363 accuracy=0.910 precision=0.926 recal=0.919 f1 score=0.9226
     64%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                           | 599/933 [02:28<01:22,  4.04it/s] 
    TRAIN iter=600 loss=0.190028 accuracy=0.918 precision=0.901 recal=0.967 f1 score=0.9325
     96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉      | 899/933 [03:42<00:08,  4.05it/s] 
    TRAIN iter=900 loss=0.170661 accuracy=0.930 precision=0.957 recal=0.918 f1 score=0.9375
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 933/933 [03:50<00:00,  4.04it/s]
    Begin evalute on dev set.
    100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:02<00:00, 14.73it/s] 
    EVALUATE [5/10]  accuracy=0.810 precision=0.775 recal=0.873 f1 score=0.8212
     32%|█████████████████████████████████████████████████████▌                                                                                                                 | 299/933 [01:14<02:40,  3.95it/s] 
    TRAIN iter=300 loss=0.125980 accuracy=0.965 precision=0.974 recal=0.968 f1 score=0.9709
     64%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                           | 599/933 [02:28<01:22,  4.05it/s] 
    TRAIN iter=600 loss=0.160912 accuracy=0.930 precision=0.928 recal=0.953 f1 score=0.9404
     96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉      | 899/933 [03:42<00:08,  4.05it/s] 
    TRAIN iter=900 loss=0.159766 accuracy=0.930 precision=0.922 recal=0.959 f1 score=0.9400
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 933/933 [03:50<00:00,  4.04it/s] 
    Begin evalute on dev set.
    100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:02<00:00, 14.74it/s] 
    EVALUATE [6/10]  accuracy=0.815 precision=0.777 recal=0.885 f1 score=0.8271
     32%|█████████████████████████████████████████████████████▌                                                                                                                 | 299/933 [01:13<02:36,  4.04it/s] 
    TRAIN iter=300 loss=0.144144 accuracy=0.941 precision=0.973 recal=0.929 f1 score=0.9508
     64%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                           | 599/933 [02:28<01:22,  4.06it/s] 
    TRAIN iter=600 loss=0.149635 accuracy=0.934 precision=0.922 recal=0.975 f1 score=0.9477
     96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉      | 899/933 [03:42<00:08,  4.06it/s] 
    TRAIN iter=900 loss=0.151699 accuracy=0.938 precision=0.926 recal=0.974 f1 score=0.9497
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 933/933 [03:50<00:00,  4.04it/s] 
    Begin evalute on dev set.
    100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:02<00:00, 14.73it/s] 
    EVALUATE [7/10]  accuracy=0.831 precision=0.806 recal=0.874 f1 score=0.8383
     32%|█████████████████████████████████████████████████████▌                                                                                                                 | 299/933 [01:14<02:36,  4.04it/s] 
    TRAIN iter=300 loss=0.191586 accuracy=0.922 precision=0.908 recal=0.967 f1 score=0.9367
     64%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                           | 599/933 [02:29<01:23,  3.98it/s] 
    TRAIN iter=600 loss=0.188188 accuracy=0.930 precision=0.947 recal=0.935 f1 score=0.9412
     96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉      | 899/933 [03:44<00:08,  4.03it/s] 
    TRAIN iter=900 loss=0.196099 accuracy=0.910 precision=0.939 recal=0.892 f1 score=0.9151
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 933/933 [03:53<00:00,  4.00it/s] 
    Begin evalute on dev set.
    100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:02<00:00, 14.66it/s] 
    EVALUATE [8/10]  accuracy=0.838 precision=0.817 recal=0.870 f1 score=0.8426
     32%|█████████████████████████████████████████████████████▌                                                                                                                 | 299/933 [01:15<02:36,  4.04it/s] 
    TRAIN iter=300 loss=0.136444 accuracy=0.953 precision=0.986 recal=0.934 f1 score=0.9592
     64%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                           | 599/933 [02:30<01:22,  4.05it/s] 
    TRAIN iter=600 loss=0.137828 accuracy=0.949 precision=0.953 recal=0.959 f1 score=0.9559
     96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉      | 899/933 [03:45<00:08,  3.98it/s] 
    TRAIN iter=900 loss=0.148434 accuracy=0.934 precision=0.947 recal=0.941 f1 score=0.9439
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 933/933 [03:53<00:00,  3.99it/s]
    Begin evalute on dev set.
    100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:02<00:00, 14.39it/s] 
    EVALUATE [9/10]  accuracy=0.840 precision=0.814 recal=0.883 f1 score=0.8471
     32%|█████████████████████████████████████████████████████▌                                                                                                                 | 299/933 [01:15<02:38,  4.01it/s] 
    TRAIN iter=300 loss=0.223042 accuracy=0.918 precision=0.904 recal=0.968 f1 score=0.9350
     64%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                           | 599/933 [02:29<01:23,  4.02it/s] 
    TRAIN iter=600 loss=0.105175 accuracy=0.965 precision=0.971 recal=0.964 f1 score=0.9677
     96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉      | 899/933 [03:45<00:08,  4.04it/s] 
    TRAIN iter=900 loss=0.110603 accuracy=0.953 precision=0.934 recal=0.986 f1 score=0.9592
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 933/933 [03:53<00:00,  4.00it/s]
    Begin evalute on dev set.
    100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:02<00:00, 14.66it/s] 
    EVALUATE [10/10]  accuracy=0.836 precision=0.819 recal=0.863 f1 score=0.8406
    100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 49/49 [00:03<00:00, 14.59it/s] 
    TEST accuracy=0.822 precision=0.762 recal=0.936 f1 score=0.8403
    

    This is the accuracy reached without any pre-trained word vectors (82.2% accuracy and an F1 of 0.8403 on the test set). When there is a chance later, we may train a word2vec embedding ourselves and combine it with the model to see how much it helps.
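
    As a rough sketch of that idea (this is not part of the original code): one could train word2vec on the tokenized training corpus with gensim and copy its vectors into the model's embedding weight. The `vocab` token-to-index mapping and the helper name below are hypothetical; tokens missing from the word2vec vocabulary keep a random initialization.

    from gensim.models import Word2Vec, KeyedVectors
    import torch
    import torch.nn as nn


    def build_pretrained_embedding(
        vocab: dict, wv: KeyedVectors, embedding_dim: int
    ) -> nn.Embedding:
        # Random init for tokens missing from the word2vec vocabulary.
        weight = torch.randn(len(vocab), embedding_dim) * 0.1
        weight[0] = 0.0  # index 0 is the padding token, keep it all zeros
        hits = 0
        for token, idx in vocab.items():
            if token in wv:
                weight[idx] = torch.from_numpy(wv[token].copy())
                hits += 1
        print(f"initialized {hits}/{len(vocab)} tokens from word2vec")
        # freeze=False lets the pre-trained vectors be fine-tuned during training
        return nn.Embedding.from_pretrained(weight, freeze=False, padding_idx=0)


    # Usage sketch, assuming `corpus` is a list of tokenized sentences:
    # wv = Word2Vec(corpus, vector_size=embedding_dim, min_count=1).wv
    # model.embedding.embedding = build_pretrained_embedding(vocab, wv, embedding_dim)

    The last line assumes the Embedding module defined earlier, which wraps nn.Embedding in its `embedding` attribute; keeping `freeze=False` is the usual choice when the pre-trained corpus is small.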

    Complete code

    https://github.com/nlp-greyfoss/nlp-in-action-public/blob/master/text_matching/re2/model.py

  • Original article: https://blog.csdn.net/yjw123456/article/details/134494932