

    Prefix-Tuning Source Code Walkthrough

    The Prefix-Tuning implementation in the PEFT package,
    adapted from https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py:

    import torch
    from transformers import PretrainedConfig
    
    
    class PrefixEncoder(torch.nn.Module):
        r'''
        The torch.nn model to encode the prefix
    
        Input shape: (batch-size, prefix-length)
    
        Output shape: (batch-size, prefix-length, 2*layers*hidden)
        '''
        def __init__(self, config):
            super().__init__()
            self.prefix_projection = config.prefix_projection
            if self.prefix_projection:
                # Use a two-layer MLP to encode the prefix
                self.embedding = torch.nn.Embedding(config.prefix_length, config.hidden_size)
                self.trans = torch.nn.Sequential(
                    torch.nn.Linear(config.hidden_size, config.encoder_hidden_size),
                    torch.nn.Tanh(),
                    torch.nn.Linear(config.encoder_hidden_size, config.num_hidden_layers * 2 * config.hidden_size)
                )
            else:
                self.embedding = torch.nn.Embedding(config.prefix_length, config.num_hidden_layers * 2 * config.hidden_size)
    
        def forward(self, prefix: torch.Tensor):
            if self.prefix_projection:
                prefix_tokens = self.embedding(prefix)
                past_key_values = self.trans(prefix_tokens)
            else:
                past_key_values = self.embedding(prefix)
            return past_key_values
        
    
    if __name__ == "__main__":
        configs = {"prefix_length":20,
                   "hidden_size":768,
                   "encoder_hidden_size":768,
                   "num_hidden_layers":12,
                   "prefix_projection":False
                   }
        
    
        prefix_encoder = PrefixEncoder(config=PretrainedConfig.from_dict(configs))
        print(prefix_encoder)
    
        batch_size = 8
        prefix = torch.arange(20).long().expand(batch_size, -1)
        print(prefix.shape)
        output = prefix_encoder(prefix)
        print(output.shape)
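
    With prefix_projection=True, the same module instead routes the prefix through the two-layer MLP. Below is a minimal comparison sketch, continuing the __main__ example above (mlp_configs and mlp_encoder are names introduced here for illustration, not part of the original code):

    mlp_configs = {**configs, "prefix_projection": True}
    mlp_encoder = PrefixEncoder(config=PretrainedConfig.from_dict(mlp_configs))
    # Same output shape as before: (batch_size, prefix_length, 2 * num_hidden_layers * hidden_size)
    print(mlp_encoder(prefix).shape)  # torch.Size([8, 20, 18432])
    # The embedding-only variant stores 20 * (2 * 12 * 768) = 368,640 weights,
    # while the MLP variant holds a (20, 768) embedding plus two Linear layers.
    print(sum(p.numel() for p in prefix_encoder.parameters()))  # 368640
    print(sum(p.numel() for p in mlp_encoder.parameters()))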
    
    

    Let us take the T5-large model as an example.
    Ignoring the "use a two-layer MLP to encode the prefix" branch, prefix tuning boils down to the following code:

    class PrefixEncoder(torch.nn.Module):
        def __init__(self, config):
            super().__init__()
            ...
            self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim)  # num_virtual_tokens=20, token_dim=1024, num_layers=24

        def forward(self, prefix: torch.Tensor):
            past_key_values = self.embedding(prefix)
            return past_key_values
    
    

    The resulting PrefixEncoder is passed into peft -> peft_model.py -> prompt_encoder (a minimal setup sketch follows the printout below):

    PrefixEncoder(
      (embedding): Embedding(20, 49152) # 1024*2*24
    )
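
    For reference, here is a minimal sketch of how this PrefixEncoder gets created for T5-large via PEFT (based on the seq2seq prefix-tuning guide referenced at the end; not part of the original post):

    from transformers import AutoModelForSeq2SeqLM
    from peft import PrefixTuningConfig, TaskType, get_peft_model

    model = AutoModelForSeq2SeqLM.from_pretrained("t5-large")
    peft_config = PrefixTuningConfig(
        task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, num_virtual_tokens=20
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()  # only the prefix is trainable: 20 * 49152 = 983,040 parameters
    print(model.prompt_encoder)         # ModuleDict holding the PrefixEncoder shown above (attribute name may vary by PEFT version)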
    
    

    self.prompt_tokens is initialized as a vector of length 2 * 20, because T5 has both an encoder and a decoder, so the prefix is needed twice:

    self.prompt_tokens[adapter_name] = torch.arange(
                config.num_virtual_tokens * config.num_transformer_submodules
            ).long() #20*2
    
    # tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
    #        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
    #        36, 37, 38, 39])
    
    
    prompt_tokens = (
        self.prompt_tokens[self.active_adapter]
        .unsqueeze(0)
        .expand(batch_size, -1)
        .to(prompt_encoder.embedding.weight.device)
    )
    prompt_tokens = prompt_tokens[:, : peft_config.num_virtual_tokens]
    # at this point prompt_tokens.shape == (batch_size=8, num_virtual_tokens=20)

    past_key_values = prompt_encoder(prompt_tokens)
    # past_key_values.shape: torch.Size([8, 20, 49152])
    
    

    At this point past_key_values still lumps all layers together, so we need to break it up per layer:

    past_key_values = past_key_values.view(
                    batch_size, #8
                    peft_config.num_virtual_tokens, #20
                    peft_config.num_layers * 2, #24*2
                    peft_config.num_attention_heads, #16
                    peft_config.token_dim // peft_config.num_attention_heads, #1024/16
                )
    # torch.Size([8, 20, 48, 16, 64])
    
    

    Because there are both an encoder and a decoder, the tensor is duplicated once more (a dummy-tensor shape check of the full pipeline follows the excerpt below):

    past_key_values = torch.cat([past_key_values, past_key_values], dim=2)
    # torch.Size([8, 20, 96, 16, 64])

    # permute -> torch.Size([96, 8, 16, 20, 64])
    # then split into a tuple of length 24, each entry of shape torch.Size([4, 8, 16, 20, 64])
    past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(
        peft_config.num_transformer_submodules * 2
    )
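
    A self-contained shape check of this view / cat / permute / split pipeline, using dummy tensors and the T5-large numbers above (an illustration, not the PEFT code itself):

    import torch

    batch_size, num_virtual_tokens = 8, 20
    num_layers, num_attention_heads, token_dim = 24, 16, 1024
    num_transformer_submodules = 2  # encoder + decoder

    pkv = torch.zeros(batch_size, num_virtual_tokens, num_layers * 2 * token_dim)
    pkv = pkv.view(batch_size, num_virtual_tokens, num_layers * 2,
                   num_attention_heads, token_dim // num_attention_heads)
    pkv = torch.cat([pkv, pkv], dim=2)                                  # torch.Size([8, 20, 96, 16, 64])
    pkv = pkv.permute([2, 0, 3, 1, 4]).split(num_transformer_submodules * 2)
    print(len(pkv))      # 24 -> one entry per layer
    print(pkv[0].shape)  # torch.Size([4, 8, 16, 20, 64])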
    
    

    In other words, past_key_values is a tuple of prefix embeddings for the 24 layers, each of shape (num_transformer_submodules * 2, batch_size, num_attention_heads, num_virtual_tokens, token_dim / num_attention_heads).

    Note that the * 2 here is because each layer needs both a key prefix and a value prefix.
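
    Continuing the dummy-tensor sketch above (the slot names below are illustrative): each per-layer entry unpacks into four (batch_size, num_attention_heads, num_virtual_tokens, head_dim) tensors, which modeling_t5.py splits into a self-attention half (past_key_value[:2]) and a cross-attention half (past_key_value[2:]):

    layer_prefix = pkv[0]                            # torch.Size([4, 8, 16, 20, 64])
    self_k, self_v, cross_k, cross_v = layer_prefix  # each torch.Size([8, 16, 20, 64])
    print(self_k.shape)

    In peft_model.py this whole pipeline lives in PeftModel.get_prompt(batch_size), which hands the ready-split tuple to the base model as past_key_values.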

    Now to the T5Attention class in transformers -> models -> t5 -> modeling_t5.py. The key step is hidden_states = torch.cat([past_key_value, hidden_states], dim=2) inside the project function; note that project is only applied to the keys and values, not to the queries.

    def forward(
            self,
            hidden_states,
            mask=None,
            key_value_states=None,
            position_bias=None,
            past_key_value=None,
            layer_head_mask=None,
            query_length=None,
            use_cache=False,
            output_attentions=False,
        ):
            """
            Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
            """
            # Input is (batch_size, seq_length, dim)
            # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
            # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
            batch_size, seq_length = hidden_states.shape[:2]
    
            real_seq_length = seq_length
    
            if past_key_value is not None:
                if len(past_key_value) != 2:
                    raise ValueError(
                        f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
                    )
                real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
    
            key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
    
            def shape(states):
                """projection"""
                return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
    
            def unshape(states):
                """reshape"""
                return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
    
            def project(hidden_states, proj_layer, key_value_states, past_key_value):
                """projects hidden states correctly to key/query states"""
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, seq_length, dim_per_head)
                    hidden_states = shape(proj_layer(hidden_states))
                elif past_key_value is None:
                    # cross-attn
                    # (batch_size, n_heads, seq_length, dim_per_head)
                    hidden_states = shape(proj_layer(key_value_states))
    
                if past_key_value is not None:
                    if key_value_states is None:
                        # self-attn
                        # (batch_size, n_heads, key_length, dim_per_head)
                    # key step: the prefix key/value is concatenated in front of the current states
                        hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                    elif past_key_value.shape[2] != key_value_states.shape[1]:
                    # checking that the `sequence_length` of the `past_key_value` is the same as
                        # the provided `key_value_states` to support prefix tuning
                        # cross-attn
                        # (batch_size, n_heads, seq_length, dim_per_head)
                        hidden_states = shape(proj_layer(key_value_states))
                    else:
                        # cross-attn
                        hidden_states = past_key_value
                return hidden_states
    
    
    
    

    query_states, key_states, and value_states are computed separately, and the attention scores are computed from the queries and keys. The resulting score tensor has shape torch.Size([8, 16, 2, 22]), so each input position can attend to itself as well as to the prefix (a standalone shape sketch follows the excerpt below).

        # hidden_states shape: torch.Size([8, 2, 1024])
        # get query states
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)
        # query_states shape: torch.Size([8, 16, 2, 64])

        # get key/value states
        key_states = project(
            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
        )
        # key_states shape: torch.Size([8, 16, 22, 64])
        value_states = project(
            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
        )
        # value_states shape: torch.Size([8, 16, 22, 64])

        # compute scores
        # scores shape: torch.Size([8, 16, 2, 22])
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
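
    To make the shapes concrete, here is a standalone sketch of what project does for the keys (dummy tensors standing in for the real projections, not the actual T5 weights):

    import torch

    batch_size, n_heads, seq_length, dim_per_head, num_virtual_tokens = 8, 16, 2, 64, 20

    prefix_key = torch.zeros(batch_size, n_heads, num_virtual_tokens, dim_per_head)  # plays the role of past_key_value[0]
    new_key = torch.zeros(batch_size, n_heads, seq_length, dim_per_head)             # plays the role of shape(self.k(hidden_states))
    key_states = torch.cat([prefix_key, new_key], dim=2)
    print(key_states.shape)  # torch.Size([8, 16, 22, 64])

    query_states = torch.zeros(batch_size, n_heads, seq_length, dim_per_head)
    scores = torch.matmul(query_states, key_states.transpose(3, 2))
    print(scores.shape)      # torch.Size([8, 16, 2, 22]) -> each of the 2 input positions sees 20 prefix slots + 2 tokens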
    
    
    

    What follows is the classic attention computation: attn_weights ([8, 16, 2, 22]) is multiplied with value_states ([8, 16, 22, 64]), contracting away the 22 dimension, which yields the output for every input position (a quick shape check follows the excerpt).

        # if key and values are already calculated
        # we want only the last query position bias
        # position_bias.shape: torch.Size([8, 16, 2, 22])

        scores += position_bias_masked

        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim) torch.Size([8, 2, 1024])
        attn_output = self.o(attn_output)

        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
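
    A quick shape check of that last matmul and the unshape step (dummy tensors; in T5-large self.o is a Linear(1024, 1024) output projection):

    import torch

    attn_weights = torch.zeros(8, 16, 2, 22)
    value_states = torch.zeros(8, 16, 22, 64)
    context = torch.matmul(attn_weights, value_states)                      # torch.Size([8, 16, 2, 64])
    attn_output = context.transpose(1, 2).contiguous().view(8, 2, 16 * 64)  # torch.Size([8, 2, 1024]), then fed to self.o
    print(attn_output.shape)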
    
    

    References

    https://huggingface.co/docs/peft/task_guides/seq2seq-prefix-tuning

    Original article: https://blog.csdn.net/weixin_42486623/article/details/133924489