• Prompt-Tuning源码分析


    Prompt-Tuning源码分析

    源码

    我们这里的代码解析以huggingface peft源码为主
    从模型类结构可以看到,Prompt Tuning 只在输入层加入 prompt virtual tokens,其他地方均没有变化,具体可查看 PromptEmbedding 的源码。

    伪代码示例

    soft_prompt=torch.nn.Parameter(#Make tensor trainable 
    torch.rand(num_tokens,embed_dim))#Initialize soft prompt tensor 
    def input_with_softprompt(x,soft_prompt):
    	x=concatenate([soft_prompt,x] #Prepend soft prompt to input 
    				  dim=seq_len)
    	return x 
    model(input_with_softprompt(x))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    peft源码

    class PromptEmbedding(torch.nn.Module):
        """
    
        ```py
        >>> from peft import PromptEmbedding, PromptTuningConfig
    
        >>> config = PromptTuningConfig(
        ...     peft_type="PROMPT_TUNING",
        ...     task_type="SEQ_2_SEQ_LM",
        ...     num_virtual_tokens=20,
        ...     token_dim=768,
        ...     num_transformer_submodules=1,
        ...     num_attention_heads=12,
        ...     num_layers=12,
        ...     prompt_tuning_init="TEXT",
        ...     prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral",
        ...     tokenizer_name_or_path="t5-base",
        ... )
    
        >>> # t5_model.shared is the word embeddings of the base model
        >>> prompt_embedding = PromptEmbedding(config, t5_model.shared)
        ```
    
        Input Shape: (`batch_size`, `total_virtual_tokens`)
    
        Output Shape: (`batch_size`, `total_virtual_tokens`, `token_dim`)
        """
    
        def __init__(self, config, word_embeddings):
            super().__init__()
    
            total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
            self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
            if config.prompt_tuning_init == PromptTuningInit.TEXT:
                from transformers import AutoTokenizer
    
                tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
                init_text = config.prompt_tuning_init_text
                init_token_ids = tokenizer(init_text)["input_ids"]
                # Trim or iterate until num_text_tokens matches total_virtual_tokens
                num_text_tokens = len(init_token_ids)
                if num_text_tokens > total_virtual_tokens:
                    init_token_ids = init_token_ids[:total_virtual_tokens]
                elif num_text_tokens < total_virtual_tokens:
                    num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
                    init_token_ids = init_token_ids * num_reps
                init_token_ids = init_token_ids[:total_virtual_tokens]
    
                word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone()
                word_embedding_weights = word_embedding_weights.to(torch.float32)
                self.embedding.weight = torch.nn.Parameter(word_embedding_weights)
    
        def forward(self, indices):
            # Just get embeddings
            prompt_embeddings = self.embedding(indices)
            return prompt_embeddings
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56

    输出的模型权重文件如下所示:

    /data/nfs/llm/model/bloomz-560m_PROMPT_TUNING_CAUSAL_LM
    ├── [ 500]  adapter_config.json
    ├── [ 33K]  adapter_model.bin
    └── [ 111]  README.md
    
    0 directories, 3 files
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    其中,adapter_config.json 为 Prompt Tuning 配置文件;adapter_model.bin 为 Prompt Tuning 权重文件。

    推理

    from peft import PeftModel, PeftConfig
    
    peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"
    
    # 加载PEFT配置
    config = PeftConfig.from_pretrained(peft_model_id)
    
    # 加载基础模型
    model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
    # 加载PEFT模型
    model = PeftModel.from_pretrained(model, peft_model_id)
    
    # Tokenizer编码
    inputs = tokenizer(f'{text_column} : {dataset["test"][i]["Tweet text"]} Label : ', return_tensors="pt")
    
    # 模型推理
    outputs = model.generate(
            input_ids=inputs["input_ids"], 
            attention_mask=inputs["attention_mask"], 
            max_new_tokens=10, 
            eos_token_id=3
        )
    
    # Tokenizer 解码
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
  • 相关阅读:
    【简答题】JavaWeb必问10道简答题
    智能终端信息安全概念(十):内核安全(2)SElinux
    百度SEO优化不稳定的原因分析(提升网站排名的稳定性)
    图论——有向图强连通分量&无向图双连通分量
    HTML——5.表单、框架、颜色
    企业知识库构建:关于企业知识库及知识平台搭建的重要性!
    网络常见的小知识点
    卷积神经网络 图像分割,卷积神经网络 图像识别
    版本控制 | 想要成为硬件设计高手?最佳实践了解一下!
    不同场景下的JMETER设置
  • 原文地址:https://blog.csdn.net/weixin_42486623/article/details/134028462