• 【论文解读】GPT Understands, Too


    一.论文

    1.1 P-tuning

    区别于之前的工作,这篇工作认为promote可以在句子中的任意位置起到作用,可以将它们插入上下文或目标中

    上图中,左图是不使用任何操作,右图是选择在居首和目标前插入promote的embedding,插入promote的过程可以表示为

    其中x代表一系列离散的输入令牌,y代表目标(可以理解为希望模型想要给你的回答),e()表示对应的embedding,其实就是将其参数化映射成为伪tokens,即

    通过最小化这些参数

    1.2 promote生成

    嵌入的promote实际上可以理解为不一定离散不相互关联的,而实际上的promote其实应该是高度离散的且具有关联性的,因此作者选择使用双向长短期记忆网络(LSTM),激活函数和MLP来建模这种关系

    在推理中,我们只需要输出嵌入h,并且可以丢弃LSTM头

    二.代码

    本质上是使用一个PromptEncoder来生成伪的embedding添加到原先的embedding中

    2.1 训练

    训练过程只更新promote_encoder中的参数

     2.1.1 PromptEncoder

    PTuneForLAMA中实例化了PromptEncoder

     PromptEncoder本质上是一个(嵌入 + LSTM + MLP)

    1. import torch
    2. import torch.nn as nn
    3. class PromptEncoder(torch.nn.Module):
    4. def __init__(self, template, hidden_size, tokenizer, device, args):
    5. super().__init__()
    6. self.device = device
    7. self.spell_length = sum(template)
    8. self.hidden_size = hidden_size
    9. self.tokenizer = tokenizer
    10. self.args = args
    11. # ent embedding
    12. self.cloze_length = template
    13. self.cloze_mask = [
    14. [1] * self.cloze_length[0] # first cloze
    15. + [1] * self.cloze_length[1] # second cloze
    16. + [1] * self.cloze_length[2] # third cloze
    17. ]
    18. self.cloze_mask = torch.LongTensor(self.cloze_mask).bool().to(self.device)
    19. self.seq_indices = torch.LongTensor(list(range(len(self.cloze_mask[0])))).to(self.device)
    20. # embedding
    21. self.embedding = torch.nn.Embedding(len(self.cloze_mask[0]), self.hidden_size).to(self.device)
    22. # LSTM
    23. self.lstm_head = torch.nn.LSTM(input_size=self.hidden_size,
    24. hidden_size=self.hidden_size // 2,
    25. num_layers=2,
    26. dropout=self.args.lstm_dropout,
    27. bidirectional=True,
    28. batch_first=True)
    29. self.mlp_head = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size),
    30. nn.ReLU(),
    31. nn.Linear(self.hidden_size, self.hidden_size))
    32. print("init prompt encoder...")
    33. def forward(self):
    34. input_embeds = self.embedding(self.seq_indices).unsqueeze(0)
    35. output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze()
    36. return output_embeds

    2.1.2 调用

    在PTuneForLAMA的forward函数中调用了embed_input来实现

  • 相关阅读:
    Python网页解析库:用requests-html爬取网页
    C++语言整理(待更新)
    .Net Core之JWT授权
    (2022版)一套教程搞定k8s安装到实战 | Ingress
    学生个人网页设计作品 学生个人网页模板 简单个人主页成品 个人网页制作 HTML学生个人网站作业设计 汉语言文学设计题材网页
    Linux 权限
    Java-------实现类(进阶)
    Python基础dict字典定义与函数
    YZ系列工具之YZ12:VBA_4种方法设计下拉列表
    人工智能 解析解法解决多元线性回归问题
  • 原文地址:https://blog.csdn.net/weixin_50862344/article/details/133962827