• 【论文解读】The Power of Scale for Parameter-Efficient Prompt Tuning


    一.介绍

    1.1 promote tuning 和 prefix tuning 的关系

    “前缀调优”的简化版

    1.2 大致实现

    冻结了整个预训练模型,并且只允许每个下游任务附加k个可调令牌到输入文本。这种“软提示”是端到端训练的,可以压缩来自完整标记数据集的信号,使我们的方法优于少量提示,并通过模型调整缩小质量差距。同时,由于单个预训练模型可用于所有下游任务,因此我们保留了冻结模型的高效服务优势

    1.3 核心贡献

    1. 提出提示调优,并在大型语言模型中展示其与模型调优的竞争力。
    2. 消除许多设计选择,并显示质量和健壮性随着规模而提高。
    3. 在域移位问题上显示提示调优优于模型调优。
    4. 提出“即时整合”并展示其有效性。

    二.promote tuning

    2.1 问题建模

    将所有任务都转换为文本生成。将分类建模为给定某些输入的输出类的概率P_r(y|X),其中X是一系列标记,y是单个类标签,现在我们将其建模为条件生成,其中y是表示类标签的标记序列。

    2.2 promote 如何work的

    提示是在Y生成过程中为模型添加额外信息的方法。通常,提示是通过在输入X前添加一系列标记P来完成的,这样模型就可以最大化生成Y的正确Y的可能性。通常,提示是通过在输入X前添加一系列标记P来完成的,这样模型就可以最大化正确Y的可能性,Pr_{\theta } (Y|[P;X]),同时保持模型参数θ不变。

    提示调优本质上就是使用专用参数\theta _p建模promote信息作为提示符,这些提示符被连接到嵌入的输入,直接通过模型(encoder-decoder架构)

    2.3 与其他工作的对比

    文章第四节对比了该方法和其他方法的异同,但是没有给出数据对比

    三.代码实现

    【pytorch参考代码】

    只训练soft promote 权重

    1. # Only update soft prompt'weights for prompt-tuning. ie, all weights in LM are set as `require_grad=False`.
    2. optimizer_grouped_parameters = [
    3. {
    4. "params": [p for n, p in model.named_parameters() if n == "soft_prompt.weight"],
    5. "weight_decay": args.weight_decay,
    6. }
    7. ]
    8. optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    9. lr_scheduler = get_scheduler(
    10. name=args.lr_scheduler_type,
    11. optimizer=optimizer,
    12. num_warmup_steps=args.num_warmup_steps,
    13. num_training_steps=args.max_train_steps,
    14. )

    初始化权重

    1. def initialize_soft_prompt(
    2. self,
    3. n_tokens: int = 20,
    4. initialize_from_vocab: bool = True,
    5. random_range: float = 0.5,
    6. ) -> None:
    7. self.n_tokens = n_tokens
    8. if initialize_from_vocab:
    9. init_prompt_value = self.transformer.wte.weight[:n_tokens].clone().detach()
    10. else:
    11. init_prompt_value = torch.FloatTensor(2, 10).uniform_(
    12. -random_range, random_range
    13. )
    14. self.soft_prompt = nn.Embedding(n_tokens, self.config.n_embd)
    15. # Initialize weight
    16. self.soft_prompt.weight = nn.parameter.Parameter(init_prompt_value)

    处理输入

    1. def _cat_learned_embedding_to_input(self, input_ids) -> torch.Tensor:
    2. inputs_embeds = self.transformer.wte(input_ids)
    3. if len(list(inputs_embeds.shape)) == 2:
    4. inputs_embeds = inputs_embeds.unsqueeze(0)
    5. # [batch_size, n_tokens, n_embd]
    6. learned_embeds = self.soft_prompt.weight.repeat(inputs_embeds.size(0), 1, 1)
    7. inputs_embeds = torch.cat([learned_embeds, inputs_embeds], dim=1)
    8. return inputs_embeds

  • 相关阅读:
    Find My手机保护壳|苹果Find My与手机保护壳结合,智能防丢,全球定位
    解决docker安装mysql的乱码问题
    常用Linux命令
    【C++】多态/虚表
    Linux aarch64交叉编译之 weston窗口管理程序
    TCP详解
    内置指令、自定义指令(详细)、全局指令与局部指令
    protobuf的复杂结构
    Teams Tab App 的 manifest 分析
    get_started_3dsctf_2016【BUUCTF】(两种解法)
  • 原文地址:https://blog.csdn.net/weixin_50862344/article/details/133953048