Utilizing Transformer Representations Efficiently


    Introduction

    • When fine-tuning a pre-trained model, we usually take the output of the Transformer's last layer, pass it through an FC / Bi-LSTM… head, and output the final result. In practice, however, each Transformer layer captures linguistic information at a different granularity (i.e. surface features in lower layers, syntactic features in middle layers, and semantic features in higher layers), so it is worth adopting a different pooling strategy for different tasks.



    Given input_ids and attention_mask, a HuggingFace Transformers model returns 2 outputs (3 if configured). The rest of this post discusses various pooling strategies for making the most of these outputs; a minimal snippet for obtaining them is shown after the list below.

    • pooler_output [batch_size, hidden_size] - Last layer hidden state of the first token of the sequence (the classification token), further processed by a Linear layer and a Tanh activation. The Linear layer weights are trained from the next-sentence-prediction (classification) objective during pretraining. The pooler can be disabled by passing add_pooling_layer=False when instantiating the model.
    • last_hidden_state [batch_size, seq_len, hidden_size] - The sequence of hidden states at the output of the last layer.
    • hidden_states [n_layers, batch_size, seq_len, hidden_size] (returned as a tuple of per-layer tensors) - Hidden states of all layers for all tokens (e.g. for base models, n_layers = 1 embedding layer + 12 Transformer layers = 13; index 0 is the embedding layer). Note: the model only returns hidden states when output_hidden_states=True is passed.
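
    As a point of reference, here is a minimal sketch (not from the original post) of how these outputs can be obtained with the HuggingFace Transformers API; the checkpoint name and input text are placeholders:

    import torch
    from transformers import AutoConfig, AutoModel, AutoTokenizer

    model_name = "bert-base-uncased"  # placeholder checkpoint
    config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)  # enable hidden_states output
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, config=config)
    # For BERT-style models the pooler can also be dropped at load time,
    # e.g. AutoModel.from_pretrained(model_name, config=config, add_pooling_layer=False)

    enc = tokenizer("Utilizing transformer representations", return_tensors="pt")
    with torch.no_grad():
        outputs = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])

    last_hidden_state = outputs.last_hidden_state           # [batch_size, seq_len, hidden_size]
    pooler_output = outputs.pooler_output                   # [batch_size, hidden_size]
    all_hidden_states = torch.stack(outputs.hidden_states)  # [13, batch_size, seq_len, hidden_size]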

    Different Pooling Strategies

    Pooler Output

    • Pooler output + FC
    logits = nn.Linear(config.hidden_size, 1)(pooler_output) # regression head
    

    Last Hidden State Output

    • [CLS] Embedding: [CLS] token embedding + FC
    cls_embeddings = last_hidden_state[:, 0]
    logits = nn.Linear(config.hidden_size, 1)(cls_embeddings) # regression head
    
    • Mean Pooling: Last Hidden State Output + Mean Pooling (remember to ignore padding tokens using the attention mask)
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
    sum_mask = input_mask_expanded.sum(1)
    sum_mask = torch.clamp(sum_mask, min=1e-9)
    mean_embeddings = sum_embeddings / sum_mask
    logits = nn.Linear(config.hidden_size, 1)(mean_embeddings) # regression head
    
    • Max Pooling: Last Hidden State Output + Max Pooling (remember to ignore padding tokens using the attention mask, i.e. set masked token embeddings to a large negative value such as -1e9 before taking the max)
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    last_hidden_state = last_hidden_state.clone()
    last_hidden_state[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
    max_embeddings = torch.max(last_hidden_state, 1)[0]
    logits = nn.Linear(config.hidden_size, 1)(max_embeddings) # regression head
    
    • Mean + Max Pooling: (1) Last Hidden State Output + Max Pooling. (2) Last Hidden State Output + Mean Pooling. (3) Concatenate the two to get a final representation that is twice the hidden size.
    mean_pooling_embeddings = torch.mean(last_hidden_state, 1)
    max_pooling_embeddings, _ = torch.max(last_hidden_state, 1)
    mean_max_embeddings = torch.cat((mean_pooling_embeddings, max_pooling_embeddings), 1)
    logits = nn.Linear(config.hidden_size*2, 1)(mean_max_embeddings)
    
    • Conv1D Pooling: Last Hidden State Output + 2 Conv1d layers
    cnn1 = nn.Conv1d(768, 256, kernel_size=2, padding=1)
    cnn2 = nn.Conv1d(256, 1, kernel_size=2, padding=1)
    
    last_hidden_state = last_hidden_state.permute(0, 2, 1)	# [batch_size, embed_size, seq_len]
    cnn_embeddings = F.relu(cnn1(last_hidden_state))	# [batch_size, 256, seq_len]
    cnn_embeddings = cnn2(cnn_embeddings)	# [batch_size, 1, seq_len]
    logits, _ = torch.max(cnn_embeddings, 2)	# [batch_size, 1]
    

    Hidden States Output

    Motivation

    • The output of the last layer may not always be the best representation of the input text when fine-tuning for downstream tasks.
    • For pre-trained language models, including Transformers, the most transferable contextualized representations of the input text tend to occur in the middle layers, while the top layers specialize in language modeling. Therefore, using only the last layer's output may restrict the power of the pre-trained representation.

    • Layerwise CLS Embeddings: e.g. use the [CLS] embedding from the second-to-last layer
    all_hidden_states = torch.stack(outputs[2])	# [n_layers, batch_size, seq_len, hidden_size]
    cls_embeddings = all_hidden_states[-2, :, 0]	# second-to-last of the 13 hidden states (embedding + 12 blocks)
    logits = nn.Linear(config.hidden_size, 1)(cls_embeddings) # regression head
    
    • Concatenate Pooling: Concatenate the [CLS] embeddings from different layers into one, e.g. concatenate the last 4 layers.
    all_hidden_states = torch.stack(outputs[2])
    concatenate_pooling = torch.cat(
        (all_hidden_states[-1], all_hidden_states[-2], all_hidden_states[-3], all_hidden_states[-4]),-1
    )
    concatenate_pooling = concatenate_pooling[:, 0]
    logits = nn.Linear(config.hidden_size*4, 1)(concatenate_pooling) # regression head
    
    • Weighted Layer Pooling: Token embeddings are a weighted mean of their representations across hidden layers. The [CLS] embedding of the weighted average is then used as the final representation. Weighted layer pooling tends to work best among these pooling techniques, regardless of the task.
    class WeightedLayerPooling(nn.Module):
        def __init__(self, num_hidden_layers, layer_start: int = 4, layer_weights = None):
            super(WeightedLayerPooling, self).__init__()
            self.layer_start = layer_start
            self.num_hidden_layers = num_hidden_layers
            self.layer_weights = layer_weights if layer_weights is not None \
                else nn.Parameter(
                    torch.tensor([1] * (num_hidden_layers+1 - layer_start), dtype=torch.float)
                )
    
        def forward(self, all_hidden_states):
            all_layer_embedding = all_hidden_states[self.layer_start:, :, :, :]
            weight_factor = self.layer_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(all_layer_embedding.size())	# [n_layer, batch_size, seq_len, embed_dim]
            weighted_average = (weight_factor*all_layer_embedding).sum(dim=0) / self.layer_weights.sum()
            return weighted_average		# [batch_size, seq_len, embed_dim]
        
    layer_start = 9		# 9th ~ 12th hidden layer
    pooler = WeightedLayerPooling(
        config.num_hidden_layers, 
        layer_start=layer_start, layer_weights=None
    )
    weighted_pooling_embeddings = pooler(all_hidden_states)	# [batch_size, seq_len, embed_dim]
    weighted_pooling_embeddings = weighted_pooling_embeddings[:, 0]	# Get CLS Embed. [batch_size, embed_dim]
    logits = nn.Linear(config.hidden_size, 1)(weighted_pooling_embeddings)
    
    • LSTM / GRU Pooling: Use an LSTM network to connect all intermediate representations of the [CLS] token; the output of the last LSTM cell is used as the final representation.
      $o = h_{LSTM}^{L} = \mathrm{LSTM}\left(h_{CLS}^{i}\right),\quad i \in [1, L]$
    class LSTMPooling(nn.Module):
        def __init__(self, num_layers, hidden_size, hiddendim_lstm):
            super(LSTMPooling, self).__init__()
            self.num_hidden_layers = num_layers
            self.hidden_size = hidden_size
            self.hiddendim_lstm = hiddendim_lstm
            self.lstm = nn.LSTM(self.hidden_size, self.hiddendim_lstm, batch_first=True)
            self.dropout = nn.Dropout(0.1)
        
        def forward(self, all_hidden_states):
            ## forward
            hidden_states = torch.stack([all_hidden_states[layer_i][:, 0]
                                         for layer_i in range(1, self.num_hidden_layers + 1)], dim=1)		# [batch_size, num_hidden_layers, hidden_size]
            out, _ = self.lstm(hidden_states, None)
            out = self.dropout(out[:, -1, :])
            return out
    
    hiddendim_lstm = 256
    pooler = LSTMPooling(config.num_hidden_layers, config.hidden_size, hiddendim_lstm)
    lstm_pooling_embeddings = pooler(all_hidden_states)
    logits = nn.Linear(hiddendim_lstm, 1)(lstm_pooling_embeddings) # regression head
    
    • Attention Pooling: We can use a dot-product attention module to dynamically combine the [CLS] representations from all intermediate layers
      $o = \operatorname{softmax}\left(q\, h_{CLS}^{T}\right) h_{CLS} W_h$, where $h_{CLS} \in \mathbb{R}^{n_l \times d_h}$ stacks the [CLS] embeddings of the $n_l$ layers, and $q \in \mathbb{R}^{1 \times d_h}$, $W_h \in \mathbb{R}^{d_h \times d_{fc}}$ are learnable weights.
    class AttentionPooling(nn.Module):
        def __init__(self, num_layers, hidden_size, hiddendim_fc):
            super(AttentionPooling, self).__init__()
            self.num_hidden_layers = num_layers
            self.hidden_size = hidden_size
            self.hiddendim_fc = hiddendim_fc
            self.dropout = nn.Dropout(0.1)
    
            q_t = np.random.normal(loc=0.0, scale=0.1, size=(1, self.hidden_size))
            self.q = nn.Parameter(torch.from_numpy(q_t).float())
            w_ht = np.random.normal(loc=0.0, scale=0.1, size=(self.hidden_size, self.hiddendim_fc))
            self.w_h = nn.Parameter(torch.from_numpy(w_ht).float())
    
        def forward(self, all_hidden_states):
            hidden_states = torch.stack([all_hidden_states[layer_i][:, 0]
                                         for layer_i in range(1, self.num_hidden_layers + 1)], dim=1)		# [batch_size, num_hidden_layers, hidden_size]
            out = self.attention(hidden_states)
            out = self.dropout(out)
            return out
    
        def attention(self, h):
            v = torch.matmul(self.q, h.transpose(-2, -1)).squeeze(1)
            v = F.softmax(v, -1)
            v_temp = torch.matmul(v.unsqueeze(1), h).transpose(-2, -1)
            v = torch.matmul(self.w_h.transpose(1, 0), v_temp).squeeze(2)
            return v
    
    hiddendim_fc = 128
    pooler = AttentionPooling(config.num_hidden_layers, config.hidden_size, hiddendim_fc)
    attention_pooling_embeddings = pooler(all_hidden_states)
    logits = nn.Linear(hiddendim_fc, 1)(attention_pooling_embeddings) # regression head
    

    • WK Pooling: SBERT-WK derives token and sentence representations via QR decomposition of each token's layer-wise representations. Note: torch QR decomposition is currently extremely slow when run on GPU, so the tensor is first transferred to the CPU before QR is applied, which makes this pooling method rather slow.

    class WKPooling(nn.Module):
        def __init__(self, layer_start: int = 4, context_window_size: int = 2):
            super(WKPooling, self).__init__()
            self.layer_start = layer_start
            self.context_window_size = context_window_size
    
        def forward(self, all_hidden_states, attention_mask):
            ft_all_layers = all_hidden_states
            org_device = ft_all_layers.device
            all_layer_embedding = ft_all_layers.transpose(1,0)
            all_layer_embedding = all_layer_embedding[:, self.layer_start:, :, :]  # Start from 4th layers output
    
            # torch.qr is slow on GPU (see https://github.com/pytorch/pytorch/issues/22573). So compute it on CPU until issue is fixed
            all_layer_embedding = all_layer_embedding.cpu()
    
            attention_mask = attention_mask.cpu().numpy()
            unmask_num = np.array([sum(mask) for mask in attention_mask]) - 1  # Not considering the last item
            embedding = []
    
            # One sentence at a time
            for sent_index in range(len(unmask_num)):
                sentence_feature = all_layer_embedding[sent_index, :, :unmask_num[sent_index], :]
                one_sentence_embedding = []
                # Process each token
                for token_index in range(sentence_feature.shape[1]):
                    token_feature = sentence_feature[:, token_index, :]
                    # 'Unified Word Representation'
                    token_embedding = self.unify_token(token_feature)
                    one_sentence_embedding.append(token_embedding)
    
                one_sentence_embedding = torch.stack(one_sentence_embedding)
                sentence_embedding = self.unify_sentence(sentence_feature, one_sentence_embedding)
                embedding.append(sentence_embedding)
    
            output_vector = torch.stack(embedding).to(org_device)
            return output_vector
    
        def unify_token(self, token_feature):
            ## Unify Token Representation
            window_size = self.context_window_size
    
            alpha_alignment = torch.zeros(token_feature.size()[0], device=token_feature.device)
            alpha_novelty = torch.zeros(token_feature.size()[0], device=token_feature.device)
    
            for k in range(token_feature.size()[0]):
                left_window = token_feature[k - window_size:k, :]
                right_window = token_feature[k + 1:k + window_size + 1, :]
                window_matrix = torch.cat([left_window, right_window, token_feature[k, :][None, :]])
                Q, R = torch.qr(window_matrix.T)
    
                r = R[:, -1]
                alpha_alignment[k] = torch.mean(self.norm_vector(R[:-1, :-1], dim=0), dim=1).matmul(R[:-1, -1]) / torch.norm(r[:-1])
                alpha_alignment[k] = 1 / (alpha_alignment[k] * window_matrix.size()[0] * 2)
                alpha_novelty[k] = torch.abs(r[-1]) / torch.norm(r)
    
            # Sum Norm
            alpha_alignment = alpha_alignment / torch.sum(alpha_alignment)  # Normalization Choice
            alpha_novelty = alpha_novelty / torch.sum(alpha_novelty)
    
            alpha = alpha_novelty + alpha_alignment
            alpha = alpha / torch.sum(alpha)  # Normalize
    
            out_embedding = torch.mv(token_feature.t(), alpha)
            return out_embedding
    
        def norm_vector(self, vec, p=2, dim=0):
            ## Implements the normalize() function from sklearn
            vec_norm = torch.norm(vec, p=p, dim=dim)
            return vec.div(vec_norm.expand_as(vec))
    
        def unify_sentence(self, sentence_feature, one_sentence_embedding):
            ## Unify Sentence By Token Importance
            sent_len = one_sentence_embedding.size()[0]
    
            var_token = torch.zeros(sent_len, device=one_sentence_embedding.device)
            for token_index in range(sent_len):
                token_feature = sentence_feature[:, token_index, :]
                sim_map = self.cosine_similarity_torch(token_feature)
                var_token[token_index] = torch.var(sim_map.diagonal(-1))
    
            var_token = var_token / torch.sum(var_token)
            sentence_embedding = torch.mv(one_sentence_embedding.t(), var_token)
    
            return sentence_embedding
        
        def cosine_similarity_torch(self, x1, x2=None, eps=1e-8):
            x2 = x1 if x2 is None else x2
            w1 = x1.norm(p=2, dim=1, keepdim=True)
            w2 = w1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True)
            return torch.mm(x1, x2.t()) / (w1 * w2.t()).clamp(min=eps)
    
    pooler = WKPooling(layer_start=9)
    wkpooling_embeddings = pooler(all_hidden_states, attention_mask)
    logits = nn.Linear(config.hidden_size, 1)(wkpooling_embeddings) # regression head
    

    More…

    • The SWA, Apex AMP & Interpreting Transformers in Torch notebook implements the Stochastic Weight Averaging technique with NVIDIA Apex on Transformers using PyTorch. It also shows how to interactively interpret Transformers using LIT (Language Interpretability Tool), a platform for NLP model understanding.
    • The On Stability of Few-Sample Transformer Fine-Tuning notebook goes over various remedies that increase few-sample fine-tuning stability and shows a significant performance improvement over plain fine-tuning.
    • The Speeding up Transformer w/ Optimization Strategies notebook explains 5 optimization strategies in depth, with code. All of these techniques are promising and can improve model performance in terms of both speed and accuracy.
    • Other strategies: Dense Pooling, Word Weight (TF-IDF) Pooling, Async Pooling, Parallel / Hierarchical Aggregation. A minimal word-weight pooling sketch follows below.
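
    As an illustration of the word-weight idea mentioned above, here is a minimal sketch (not from the original post) that weights each token's last-hidden-state vector by a precomputed per-token weight such as an IDF score; the idf table below is a placeholder:

    # Hypothetical per-token weights, e.g. IDF scores indexed by vocabulary id.
    idf = torch.ones(config.vocab_size)                       # placeholder: uniform weights
    token_weights = idf[input_ids] * attention_mask.float()   # [batch_size, seq_len], padding zeroed out
    token_weights = token_weights / token_weights.sum(dim=1, keepdim=True).clamp(min=1e-9)

    weighted_embeddings = (last_hidden_state * token_weights.unsqueeze(-1)).sum(dim=1)  # [batch_size, hidden_size]
    logits = nn.Linear(config.hidden_size, 1)(weighted_embeddings)  # regression head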

    References

  • Original post: https://blog.csdn.net/weixin_42437114/article/details/127942235