• python-pytorch 实现seq2seq+luong general concat attention笔记1.0.10


    python-pytorch 实现seq2seq+luong general concat attention笔记1.0.10

    上篇文章是使用luong的dot计算分数方法实现seq2seq attention简单对话,这篇文件使用general记分方法

    只需要替换Attention类

    1. general注意力
      就是在dot方法基础上,对hidden做一个线性变换
    class Attention(nn.Module):
        def __init__(self):
            super(Attention, self).__init__()
            self.va=nn.Linear(hidden_size,hidden_size,bias=False)
            
        def forward(self, hidden, encoder_outputs):
            """
            hidden:[layer_num,batch_size,hidden_size]
            encoder_outputs:[seq_len,batch_size,hidden_size]
            """
            score=encoder_outputs.permute(1,0,2).bmm(self.va(hidden).permute(1,2,0))# [batch_size,seq_len,layer_num]
            
            attr=nn.functional.softmax(score,dim=1)# [batch_size,seq_len,layer_num]
            
            context=attr.permute(0,2,1).bmm(encoder_outputs.permute(1,0,2))
    
            
            return context,attr
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    1. concat注意力
      这里需要注意的是,在计算出tanh后,需要自定义个va的矩阵相乘,大小是当前的[batch_size,hidden_size*2]
    class Attention(nn.Module):
        def __init__(self):
            super(Attention, self).__init__()
            self.wa=nn.Linear(hidden_size*2,hidden_size*2,bias=False)
            self.wa1=nn.Linear(hidden_size*2,hidden_size,bias=False)
            
        def forward(self, hidden, encoder_outputs):
            """
            hidden:[layer_num,batch_size,hidden_size]
            encoder_outputs:[seq_len,batch_size,hidden_size]
            """
            
            hiddenchange=hidden.repeat(encoder_outputs.size(0),1,1)#[seq_len,batch_size,hidden_size]
            
            concated=torch.cat([hiddenchange.permute(1,0,2),encoder_outputs.permute(1,0,2)],dim=-1)# [batch_size,seq_len,hidden_size*2]
            waed=self.wa(concated)# [batch_size,seq_len,hidden_size*2]
            tanhed=torch.tanh(waed)# [batch_size,seq_len,hidden_size*2]
            
            self.va=nn.Parameter(torch.FloatTensor(encoder_outputs.size(1),hidden_size*2))#[batch_size,hidden_size*2]
    #         print("tanhed size",tanhed.size(),self.va.unsqueeze(2).size())
            
            attr=tanhed.bmm(self.va.unsqueeze(2))# [batch_size,seq_len,1]
            
            context=attr.permute(0,2,1).bmm(encoder_outputs.permute(1,0,2))# [batch_size,1,seq_len]
    
            return context,attr
            
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27

    三者训练结果对比

    文字太多存不下,放到另外一篇文章了,链接是:https://mp.csdn.net/mp_blog/creation/success/139053790

    总结

    • 将线性变换设置bias为True的时候,收敛会变快
    • 之前学习率设置是0.001,然后改为0.05后,起始loss变为1.xxxx,4000次迭代loss就达到了0.00x

    完整代码

    参见链接https://blog.csdn.net/m0_60688978/article/details/139053661

    参考

    https://blog.csdn.net/m0_60688978/article/details/139044526

  • 相关阅读:
    招投标系统简介 招投标系统源码 java招投标系统 招投标系统功能设计
    详细介绍如何使用YOLOV8和KerasCV进行高效物体检测
    LeetCode 958. 二叉树的完全性校验
    J2EE--通用分页
    心电信号导出呼吸频率的算法
    学习笔记-sliver
    【深入浅出Spring原理及实战】「源码调试分析」深入源码探索Spring底层框架的的refresh方法所出现的问题和异常
    源码编译安装 LAMP
    es(Elasticsearch)客户端Kibana安装使用(03Kibana安装篇)
    终于把相册集成到摄像头APP
  • 原文地址:https://blog.csdn.net/m0_60688978/article/details/139046644