• 李沐动手学深度学习V2-bert和代码实现


    一. BERT(来自Transformers的双向编码器表示)

    1. 介绍

    BERT通过使用预训练的Transformer编码器,能够基于其双向上下文表示任何词元,在下游任务的监督学习过程中,BERT在两个方面与GPT相似。首先BERT表示将被输入到一个添加的输出层中,根据任务的性质对模型架构进行最小的更改,例如预测每个词元与预测整个序列。其次对预训练Transformer编码器的所有参数进行微调,而额外的输出层将从头开始训练

    2. 输入表示

    在自然语言处理中,有些任务(如情感分析)以单个文本作为输入,而有些任务(如自然语言推断)以一对文本序列作为输入。BERT输入序列明确地表示单个文本和文本对。当输入为单个文本时,BERT输入序列是特殊类别词元“”、文本序列的标记、以及特殊分隔词元“”的连结。当输入为文本对时,BERT输入序列是“”、第一个文本序列的标记、“”、第二个文本序列标记、以及“”的连结。我们将始终如一地将术语“BERT输入序列”与其他类型的“序列”区分开来。例如,一个BERT输入序列可以包括一个文本序列或两个文本序列。为了区分文本对,根据输入序列学到的片段嵌入 e A \mathbf{e}_A eA e B \mathbf{e}_B eB分别被添加到第一序列和第二序列的词元嵌入中。对于单文本输入,仅使用 e A \mathbf{e}_A eA
    下面的get_tokens_and_segments将一个句子或两个句子作为输入,然后返回BERT输入序列及其相应的序列对的片段索引。

    import torch
    import d2l.torch
    from torch import nn
    def get_tokens_segments(tokens_a,tokens_b=None):
        """获取输入序列的词元及其片段索引"""
        tokens = ['']+tokens_a+['']
        # 0和1分别标记片段A和B
        segments = [0]*(len(tokens_a)+2)
        if tokens_b is not None:
            tokens += tokens_b+['']
            segments += [1]*(len(tokens_b)+1)
        return tokens,segments
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    BERT选择Transformer编码器作为其双向架构。在Transformer编码器中位置嵌入被加入到输入序列的每个位置,然而与原始的Transformer编码器不同,BERT使用可学习的位置嵌入总之,bert-input是BERT输入序列的词元嵌入、片段嵌入和位置嵌入的和。如下图所示
    bert输入
    下面的BERTEncoder类与transformer的TransformerEncoder类一样。不同的是,BERTEncoder使用片段嵌入和可学习的位置嵌入。

    class BERTEncoder(nn.Module):
        """BERT编码器"""
        def __init__(self,vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,num_layers,dropout,max_len=1000,
                     key_size=768,query_size=768,value_size=768,use_bias=True):
            super(BERTEncoder,self).__init__()
            self.token_embedding = nn.Embedding(vocab_size,num_hiddens)
            self.segment_embedding = nn.Embedding(2,num_hiddens)
            # 在BERT中,位置嵌入是可学习的,因此我们创建一个足够长的位置嵌入参数
            self.pos_embedding = nn.Parameter(torch.randn(size=(1,max_len,num_hiddens)))
            self.blks = nn.Sequential()
            for i in range(num_layers):
                self.blks.add_module(f'{i}',d2l.torch.EncoderBlock(key_size,query_size,value_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,dropout,use_bias))
        def forward(self,tokens,segments,valid_lens):
            # 在以下代码段中,X的形状保持不变:(批量大小,最大序列长度,num_hiddens)
            X = self.token_embedding(tokens)+self.segment_embedding(segments)
            X += self.pos_embedding.data[:,:X.shape[1],:]
            for blk in self.blks:
                X = blk(X,valid_lens)
            return X
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    假设词表大小为10000,为了演示BERTEncoder的前向推断,创建一个实例并初始化它的参数。

    vocab_size,num_hiddens,ffn_num_input,ffn_num_hiddens,num_heads,num_layers = 1000,768,768,1024,4,2
    norm_shape,dropout = [768],0.2
    encoder = BERTEncoder(vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,num_layers,dropout)
    
    • 1
    • 2
    • 3

    将tokens定义为长度为8的2个输入序列,其中每个词元是词表的索引。使用输入tokens的BERTEncoder的前向推断返回编码结果,其中每个词元由向量表示,其长度由超参数num_hiddens定义,此超参数通常称为Transformer编码器的隐藏大小(隐藏单元数)。

    tokens = torch.randint(0,vocab_size,(2,8))
    segments = torch.tensor([[0,0,0,0,1,1,1,1],[0,0,0,1,1,1,1,1]])
    enc_outputs = encoder(tokens,segments,None)
    enc_outputs.shape
    
    • 1
    • 2
    • 3
    • 4
    输出结果如下:
    torch.Size([2, 8, 768])
    
    • 1
    • 2

    3.预训练任务

    BERTEncoder的前向推断给出了输入文本的每个词元和插入的特殊标记“”及“”的BERT表示。接下来将使用这些表示来计算预训练BERT的损失函数。预训练包括以下两个任务:掩蔽语言模型和下一句预测。

    3.1 掩蔽语言模型(Masked Language Modeling)

    为了双向编码上下文以表示每个词元,BERT随机掩蔽词元并使用来自双向上下文的词元以自监督的方式预测掩蔽词元,此任务称为掩蔽语言模型(完形填空)。
    在这个预训练任务中,将随机选择15%的词元作为预测的掩蔽词元。要预测一个掩蔽词元而不使用标签作弊,一个简单的方法是总是用一个特殊的“”替换输入序列中的词元。然而,人造特殊词元“”不会出现在微调中。为了避免预训练和微调之间的这种不匹配,如果为预测而屏蔽词元(例如,在“this movie is great”中选择掩蔽和预测“great”),则在输入中将其替换为:

    • 80%时间为特殊的““词元(例如,“this movie is great”变为“this movie is”;
    • 10%时间为随机词元(例如,“this movie is great”变为“this movie is drink”);
    • 10%时间内为不变的标签词元(例如,“this movie is great”变为“this movie is great”)。
      注意,在20%的时间中,有10%的时间插入了随机词元。这种偶然的噪声鼓励BERT在其双向上下文编码中不那么偏向于掩蔽词元(尤其是当标签词元保持不变时)。
      实现下面的MaskLM类来预测BERT预训练的掩蔽语言模型任务中的掩蔽标记。预测使用单隐藏层的多层感知机(self.mlp)。在前向推断中,它需要两个输入:BERTEncoder的编码结果和用于预测的词元位置,输出是这些位置的预测结果。
    class MaskLM(nn.Module):
        """BERT的掩蔽语言模型任务"""
        def __init__(self,vocab_size,num_hiddens,num_inputs=768,**kwargs):
            super(MaskLM,self).__init__()
            self.mlp = nn.Sequential(nn.Linear(num_inputs,num_hiddens),
                                     nn.ReLU(),
                                     nn.LayerNorm(num_hiddens),
                                     nn.Linear(num_hiddens,vocab_size))
        def forward(self,X,pred_positions):
            num_pred_positions = pred_positions.shape[1]
            pred_positions_id = pred_positions.reshape(-1)
            batch_size = X.shape[0]
            batch_id = torch.arange(0,batch_size)
            batch_idx = torch.repeat_interleave(batch_id,num_pred_positions)
            # 假设batch_size=2,num_pred_positions=3
            # 那么batch_idx是np.array([0,0,0,1,1,1])
            masked_X = X[batch_idx,pred_positions_id]
            masked_X = masked_X.reshape((batch_size,num_pred_positions,-1))
            mlm_Y_hat = self.mlp(masked_X)
            return mlm_Y_hat
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    为了演示MaskLM的前向推断,创建了其实例mlm并对其进行了初始化。将mlm_positions定义为在encoded_X的任一输入序列中预测的3个指示。mlm的前向推断返回encoded_X的所有掩蔽位置mlm_positions处的预测结果mlm_Y_hat。对于每个预测,结果的大小等于词表的大小。

    mlm = MaskLM(vocab_size,num_hiddens)
    pred_positions = torch.tensor([[1,5,2],[6,1,5]])
    mlm_Y_hat = mlm(enc_outputs,pred_positions)
    mlm_Y_hat.shape
    
    • 1
    • 2
    • 3
    • 4
    输出结果如下:
    torch.Size([2, 3, 1000])
    
    • 1
    • 2

    通过掩码下的预测词元mlm_Y的真实标签mlm_Y_hat,可以计算在BERT预训练中的遮蔽语言模型任务的交叉熵损失。

    mlm_Y = torch.tensor([[7,8,9],[10,11,12]])
    loss = nn.CrossEntropyLoss(reduction='none')
    mlm_loss = loss(mlm_Y_hat.reshape((-1,vocab_size)),mlm_Y.reshape(-1))
    print('mlm_loss:',mlm_loss,'\n shape:',mlm_loss.shape)
    
    • 1
    • 2
    • 3
    • 4
    输出结果如下:
    mlm_loss: tensor([7.0592, 6.6084, 6.6682, 6.9848, 8.1962, 6.7052],
           grad_fn=<NllLossBackward0>) 
     shape: torch.Size([6])
    
    • 1
    • 2
    • 3
    • 4

    3.2 下一句预测(Next Sentence Prediction)

    尽管掩蔽语言建模能够编码双向上下文来表示单词,但它不能显式地建模文本对之间的逻辑关系。为了帮助理解两个文本序列之间的关系,BERT在预训练中考虑了一个二元分类任务——下一句预测。在为预训练生成句子对时,有一半的时间它们是标签为“真”的连续句子;在另一半的时间里,第二个句子是从语料库中随机抽取的,标记为“假”。
    下面的NextSentencePred类使用单隐藏层的多层感知机来预测第二个句子是否是BERT输入序列中第一个句子的下一个句子。由于Transformer编码器中的自注意力,特殊词元“”的BERT表示已经对输入的两个句子进行了编码。因此,多层感知机分类器的输出层(self.output)以X作为输入,其中X是多层感知机隐藏层的输出,而MLP隐藏层的输入是编码后的“”词元

    class NextSentencePred(nn.Module):
        """BERT的下一句预测任务"""
        def __init__(self,num_inputs):
            super(NextSentencePred,self).__init__()
            self.output = nn.Linear(num_inputs,2)
        def forward(self,X):
            # X的形状:(batchsize,num_hiddens)
            return self.output(X)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    可以看到,NextSentencePred实例的前向推断返回每个BERT输入序列的二分类预测。

    # NSP的输入形状:(batchsize,num_hiddens)
    nsp = NextSentencePred(enc_outputs.shape[-1])
    nsp_Y_hat = nsp(enc_outputs[:,0,:]) #只把(每个序列第一个词元的特征维度)的特征维度输入nsp中就行
    print('nsp_Y_hat:',nsp_Y_hat,'\nnsp_Y_hat_shape:',nsp_Y_hat.shape)
    
    • 1
    • 2
    • 3
    • 4
    输出结果如下:
    nsp_Y_hat: tensor([[0.4534, 0.2836],
            [0.5663, 0.1450]], grad_fn=<AddmmBackward0>) 
    nsp_Y_hat_shape: torch.Size([2, 2])
    
    • 1
    • 2
    • 3
    • 4

    计算两个二元分类的交叉熵损失。

    nsp_Y = torch.tensor([0,1])
    nsp_loss = loss(nsp_Y_hat,nsp_Y)
    print('nsp_loss:',nsp_loss,'\nnsp_loss_shape:',nsp_loss.shape)
    
    • 1
    • 2
    • 3
    输出结果如下:
    nsp_loss: tensor([0.6119, 0.9258], grad_fn=<NllLossBackward0>) 
    nsp_loss_shape: torch.Size([2])
    
    • 1
    • 2
    • 3

    4. 整合代码

    在预训练BERT时,最终的损失函数是掩蔽语言模型损失函数和下一句预测损失函数的线性组合。现在通过实例化三个类BERTEncoder、MaskLM和NextSentencePred来定义BERTModel类。前向推断返回编码后的BERT表示encoded_X、掩蔽语言模型预测mlm_Y_hat和下一句预测nsp_Y_hat

    class BERTModel(nn.Module):
        """BERT模型"""
        def __init__(self,vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,num_layers,dropout,max_len=1000,key_size=768,query_size=768,value_size=768,use_bias=True,hid_in_features=768,mlm_in_features=768,nsp_in_features=768):
            super(BERTModel,self).__init__()
            self.encoder = BERTEncoder(vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,num_layers,dropout,max_len,key_size,query_size,value_size,use_bias)
            self.mlm = MaskLM(vocab_size,num_hiddens,mlm_in_features)
            self.hidden = nn.Sequential(nn.Linear(hid_in_features,num_hiddens),
                                         nn.Tanh())
            self.nsp = NextSentencePred(nsp_in_features)
        def forward(self,tokens,segments,valid_lens=None,pred_positions=None):
            encoder_X = self.encoder(tokens,segments,valid_lens)
            if pred_positions is not None:
                mlm_Y_hat = self.mlm(encoder_X,pred_positions)
            else:
                mlm_Y_hat = None
            # 用于下一句预测的多层感知机分类器的隐藏层,0是“”标记的索引
            nsp_Y_hat = self.nsp(self.hidden(encoder_X[:,0,:]))
            return encoder_X,mlm_Y_hat,nsp_Y_hat
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    5. 小结

    • word2vec和GloVe等词嵌入模型与上下文无关。它们将相同的预训练向量赋给同一个词,而不考虑词的上下文(如果有的话)。它们很难处理好自然语言中的一词多义或复杂语义。
    • 对于上下文敏感的词表示,如ELMo和GPT,词的表示依赖于它们的上下文。
    • ELMo对上下文进行双向编码,但使用特定于任务的架构(然而,为每个自然语言处理任务设计一个特定的体系架构实际上并不容易);而GPT是任务无关的,但是从左到右编码上下文。
    • BERT结合了这两个方面的优点:它对上下文进行双向编码,并且需要对大量自然语言处理任务进行最小的架构更改。
    • BERT输入序列的嵌入是词元嵌入、片段嵌入和位置嵌入的和。
    • 预训练包括两个任务:掩蔽语言模型和下一句预测。前者能够编码双向上下文来表示单词,而后者则显式地建模文本对之间的逻辑关系。

    6. 全部代码

    import torch
    import d2l.torch
    from torch import nn
    
    
    def get_tokens_segments(tokens_a, tokens_b=None):
        """获取输入序列的词元及其片段索引"""
        tokens = [''] + tokens_a + ['']
        # 0和1分别标记片段A和B
        segments = [0] * (len(tokens_a) + 2)
        if tokens_b is not None:
            tokens += tokens_b + ['']
            segments += [1] * (len(tokens_b) + 1)
        return tokens, segments
    
    
    class BERTEncoder(nn.Module):
        """BERT编码器"""
    
        def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers,
                     dropout, max_len=1000,
                     key_size=768, query_size=768, value_size=768, use_bias=True):
            super(BERTEncoder, self).__init__()
            self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
            self.segment_embedding = nn.Embedding(2, num_hiddens)
            # 在BERT中,位置嵌入是可学习的,因此我们创建一个足够长的位置嵌入参数
            self.pos_embedding = nn.Parameter(torch.randn(size=(1, max_len, num_hiddens)))
            self.blks = nn.Sequential()
            for i in range(num_layers):
                self.blks.add_module(f'{i}',
                                     d2l.torch.EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                                                            ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias))
    
        def forward(self, tokens, segments, valid_lens):
            # 在以下代码段中,X的形状保持不变:(批量大小,最大序列长度,num_hiddens)
            X = self.token_embedding(tokens) + self.segment_embedding(segments)
            X += self.pos_embedding.data[:, :X.shape[1], :]
            for blk in self.blks:
                X = blk(X, valid_lens)
            return X
    
    
    vocab_size, num_hiddens, ffn_num_input, ffn_num_hiddens, num_heads, num_layers = 1000, 768, 768, 1024, 4, 2
    norm_shape, dropout = [768], 0.2
    encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers,
                          dropout)
    tokens = torch.randint(0, vocab_size, (2, 8))
    segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
    enc_outputs = encoder(tokens, segments, None)
    enc_outputs.shape
    
    
    class MaskLM(nn.Module):
        """BERT的掩蔽语言模型任务"""
    
        def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs):
            super(MaskLM, self).__init__()
            self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),
                                     nn.ReLU(),
                                     nn.LayerNorm(num_hiddens),
                                     nn.Linear(num_hiddens, vocab_size))
    
        def forward(self, X, pred_positions):
            num_pred_positions = pred_positions.shape[1]
            pred_positions_id = pred_positions.reshape(-1)
            batch_size = X.shape[0]
            batch_id = torch.arange(0, batch_size)
            batch_idx = torch.repeat_interleave(batch_id, num_pred_positions)
            # 假设batch_size=2,num_pred_positions=3
            # 那么batch_idx是np.array([0,0,0,1,1,1])
            masked_X = X[batch_idx, pred_positions_id]
            masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))
            mlm_Y_hat = self.mlp(masked_X)
            return mlm_Y_hat
    
    
    mlm = MaskLM(vocab_size, num_hiddens)
    pred_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])
    mlm_Y_hat = mlm(enc_outputs, pred_positions)
    mlm_Y_hat.shape
    mlm_Y = torch.tensor([[7, 8, 9], [10, 11, 12]])
    loss = nn.CrossEntropyLoss(reduction='none')
    mlm_loss = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))
    print('mlm_loss:', mlm_loss, '\n shape:', mlm_loss.shape)
    
    
    class NextSentencePred(nn.Module):
        """BERT的下一句预测任务"""
    
        def __init__(self, num_inputs):
            super(NextSentencePred, self).__init__()
            self.output = nn.Linear(num_inputs, 2)
    
        def forward(self, X):
            # X的形状:(batchsize,num_hiddens)
            return self.output(X)
    
    
    # NSP的输入形状:(batchsize,num_hiddens)
    nsp = NextSentencePred(enc_outputs.shape[-1])
    nsp_Y_hat = nsp(enc_outputs[:, 0, :])  #只把(每个序列第一个词元的特征维度)的特征维度输入nsp中就行
    print('nsp_Y_hat:', nsp_Y_hat, '\nnsp_Y_hat_shape:', nsp_Y_hat.shape)
    nsp_Y = torch.tensor([0, 1])
    nsp_loss = loss(nsp_Y_hat, nsp_Y)
    print('nsp_loss:', nsp_loss, '\nnsp_loss_shape:', nsp_loss.shape)
    
    
    class BERTModel(nn.Module):
        """BERT模型"""
    
        def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers,
                     dropout, max_len=1000, key_size=768, query_size=768, value_size=768, use_bias=True,
                     hid_in_features=768, mlm_in_features=768, nsp_in_features=768):
            super(BERTModel, self).__init__()
            self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                                       num_layers, dropout, max_len, key_size, query_size, value_size, use_bias)
            self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)
            self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens),
                                        nn.Tanh())
            self.nsp = NextSentencePred(nsp_in_features)
    
        def forward(self, tokens, segments, valid_lens=None, pred_positions=None):
            encoder_X = self.encoder(tokens, segments, valid_lens)
            if pred_positions is not None:
                mlm_Y_hat = self.mlm(encoder_X, pred_positions)
            else:
                mlm_Y_hat = None
            # 用于下一句预测的多层感知机分类器的隐藏层,0是“”标记的索引
            nsp_Y_hat = self.nsp(self.hidden(encoder_X[:, 0, :]))
            return encoder_X, mlm_Y_hat, nsp_Y_hat
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130

    7. 相关链接

    BERT预训练第一篇:李沐动手学深度学习V2-bert和代码实现
    BERT预训练第二篇:李沐动手学深度学习V2-bert预训练数据集和代码实现
    BERT预训练第三篇:李沐动手学深度学习V2-BERT预训练和代码实现
    BERT微调第一篇:李沐动手学深度学习V2-自然语言推断与数据集SNLI和代码实现
    BERT微调第二篇:李沐动手学深度学习V2-BERT微调和代码实现

  • 相关阅读:
    我发布了一款基于RBAC权限模型实现的通用后台管理系统
    解决vue element - ui 弹窗打开表单自动校验问题
    数据库 | VirusCircBase:环状 RNA病毒数据库
    装饰器模式
    threejs的阴影
    java计算机毕业设计交通非现场执法系统源码+mysql数据库+系统+lw文档+部署
    K8S核心概念之SVC(易混淆难理解知识点总结)
    Elasticsearch:如何为 CCR 及 CCS 建立带有安全的集群之间的互信
    DJ12-2-4 串操作指令
    使用curl执行Http请求
  • 原文地址:https://blog.csdn.net/flyingluohaipeng/article/details/126093083