The previous article implemented a simple seq2seq-attention chatbot using Luong's dot score; this article uses the general scoring method instead.
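For reference, the general score is score(h_t, h̄_s) = h_tᵀ · W_a · h̄_s: the decoder hidden state is passed through a learned linear map W_a before taking the dot product with each encoder output. Below is a minimal stand-alone sketch of just this scoring step (the sizes and tensor names are illustrative assumptions, not taken from the project code):

import torch
import torch.nn as nn

hidden_size = 256                                      # assumed size, for illustration only
wa = nn.Linear(hidden_size, hidden_size, bias=False)   # W_a of the general score

h_t = torch.randn(1, hidden_size)                      # one decoder hidden state
h_s = torch.randn(10, hidden_size)                     # encoder outputs for a length-10 source
score = h_s.matmul(wa(h_t).squeeze(0))                 # general score per source position -> [10]
weights = torch.softmax(score, dim=0)                  # attention weights that sum to 1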
import torch
import torch.nn as nn

# hidden_size is assumed to be defined earlier in the post
class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        # W_a of the general score: score(h_t, h_s) = h_t^T W_a h_s
        self.va = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, hidden, encoder_outputs):
        """
        hidden:          [layer_num, batch_size, hidden_size]
        encoder_outputs: [seq_len, batch_size, hidden_size]
        """
        # [batch_size, seq_len, hidden_size] @ [batch_size, hidden_size, layer_num] -> [batch_size, seq_len, layer_num]
        score = encoder_outputs.permute(1, 0, 2).bmm(self.va(hidden).permute(1, 2, 0))
        attr = nn.functional.softmax(score, dim=1)  # [batch_size, seq_len, layer_num]
        # weighted sum of the encoder outputs -> [batch_size, layer_num, hidden_size]
        context = attr.permute(0, 2, 1).bmm(encoder_outputs.permute(1, 0, 2))
        return context, attr
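A quick shape check for the class above (the concrete numbers are illustrative assumptions; hidden_size must match the value used elsewhere in the post):

hidden_size, layer_num, batch_size, seq_len = 256, 1, 4, 10     # assumed values for illustration
attn = Attention()
hidden = torch.randn(layer_num, batch_size, hidden_size)        # decoder hidden state
encoder_outputs = torch.randn(seq_len, batch_size, hidden_size)
context, attr = attn(hidden, encoder_outputs)
print(context.shape)  # torch.Size([4, 1, 256]) -> [batch_size, layer_num, hidden_size]
print(attr.shape)     # torch.Size([4, 10, 1])  -> [batch_size, seq_len, layer_num]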
A second version concatenates the decoder state with every encoder output and scores the pair with v_aᵀ · tanh(W_a[h_t; h̄_s]):

class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.wa = nn.Linear(hidden_size * 2, hidden_size * 2, bias=False)
        self.wa1 = nn.Linear(hidden_size * 2, hidden_size, bias=False)  # declared but not used in forward
        # v_a must be created once here rather than inside forward, otherwise it is
        # re-initialized on every call and never trained
        self.va = nn.Parameter(torch.randn(hidden_size * 2))

    def forward(self, hidden, encoder_outputs):
        """
        hidden:          [layer_num, batch_size, hidden_size] (layer_num is assumed to be 1)
        encoder_outputs: [seq_len, batch_size, hidden_size]
        """
        hiddenchange = hidden.repeat(encoder_outputs.size(0), 1, 1)  # [seq_len, batch_size, hidden_size]
        # pair the decoder state with every encoder output
        concated = torch.cat([hiddenchange.permute(1, 0, 2), encoder_outputs.permute(1, 0, 2)], dim=-1)  # [batch_size, seq_len, hidden_size*2]
        waed = self.wa(concated)   # [batch_size, seq_len, hidden_size*2]
        tanhed = torch.tanh(waed)  # [batch_size, seq_len, hidden_size*2]
        # score = v_a^T tanh(W_a [h_t; h_s]) -> [batch_size, seq_len, 1]
        score = tanhed.matmul(self.va).unsqueeze(2)
        attr = nn.functional.softmax(score, dim=1)  # [batch_size, seq_len, 1]
        context = attr.permute(0, 2, 1).bmm(encoder_outputs.permute(1, 0, 2))  # [batch_size, 1, hidden_size]
        return context, attr
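The same kind of illustrative shape check applies to this variant; both versions return a (context, attr) pair with matching shapes when layer_num is 1, so either can be plugged into the decoder (the sizes below are again assumptions):

hidden_size, batch_size, seq_len = 256, 4, 10                   # assumed values for illustration
attn2 = Attention()
hidden = torch.randn(1, batch_size, hidden_size)                # single-layer decoder state
encoder_outputs = torch.randn(seq_len, batch_size, hidden_size)
context, attr = attn2(hidden, encoder_outputs)
print(context.shape)  # torch.Size([4, 1, 256]) -> [batch_size, 1, hidden_size]
print(attr.shape)     # torch.Size([4, 10, 1])  -> [batch_size, seq_len, 1]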
There was too much text to fit in one post, so the rest has been moved to another article: https://mp.csdn.net/mp_blog/creation/success/139053790
See also:
https://blog.csdn.net/m0_60688978/article/details/139053661
https://blog.csdn.net/m0_60688978/article/details/139044526