For each element of the input sequence, each layer computes the following functions:
$$
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$
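Here $x_t$ is the input at time $t$, $h_t$ and $c_t$ are the hidden state and cell state at time $t$, and $h_{t-1}$ is the hidden state at time $t-1$ (or the initial hidden state at time 0). $i_t$, $f_t$, $g_t$ and $o_t$ are the input gate, forget gate, cell (candidate) gate and output gate respectively; $\sigma$ is the sigmoid function and $\odot$ is the element-wise (Hadamard) product.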
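To make the gate equations concrete, here is a minimal sketch that evaluates them by hand for a single-layer `nn.LSTM` and compares the result with the module's own output. The helper name `manual_lstm_step` is made up for this illustration; the sketch relies on PyTorch stacking the gate weights in the order i, f, g, o inside `weight_ih_l0` and `weight_hh_l0`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size = 10, 20
lstm = nn.LSTM(input_size, hidden_size, num_layers=1)

def manual_lstm_step(x_t, h_prev, c_prev, lstm):
    """One time step of the gate equations (hypothetical helper for illustration)."""
    # All four pre-activations at once; rows are stacked as i, f, g, o.
    gates = (x_t @ lstm.weight_ih_l0.T + lstm.bias_ih_l0
             + h_prev @ lstm.weight_hh_l0.T + lstm.bias_hh_l0)
    i, f, g, o = gates.chunk(4, dim=-1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    c_t = f * c_prev + i * g       # c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t
    h_t = o * torch.tanh(c_t)      # h_t = o_t ⊙ tanh(c_t)
    return h_t, c_t

x = torch.randn(5, 3, input_size)  # (seq_len, batch_size, input_size)
h = torch.zeros(3, hidden_size)    # zero initial state, matching the module default
c = torch.zeros(3, hidden_size)

with torch.no_grad():
    ref_out, (ref_h, ref_c) = lstm(x)
    steps = []
    for t in range(x.size(0)):
        h, c = manual_lstm_step(x[t], h, c, lstm)
        steps.append(h)
    manual_out = torch.stack(steps)

# Largest absolute difference between the hand-written loop and nn.LSTM.
max_diff = (manual_out - ref_out).abs().max().item()
print(max_diff)
```

The difference should be at the level of floating-point rounding, which suggests the loop and the module compute the same recurrence.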
import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 2)         # input_size, hidden_size, num_layers
inputs = torch.randn(5, 3, 10)   # (seq_len, batch_size, input_size)
h0 = torch.randn(2, 3, 20)       # (num_layers * num_directions, batch_size, hidden_size)
c0 = torch.randn(2, 3, 20)       # (num_layers * num_directions, batch_size, hidden_size)
output, (hn, cn) = rnn(inputs, (h0, c0))
print(output.shape)  # torch.Size([5, 3, 20]): (seq_len, batch_size, hidden_size)
print(hn.shape)      # torch.Size([2, 3, 20]): (num_layers * num_directions, batch_size, hidden_size)
print(cn.shape)      # torch.Size([2, 3, 20]): (num_layers * num_directions, batch_size, hidden_size)
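The shapes above are easier to remember once the relationship between `output` and `hn` is clear. The following sketch (variable names are chosen for this illustration) checks that, for a unidirectional LSTM, the last time step of `output` is the final hidden state of the top layer, and shows how the shapes change with `bidirectional=True`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(5, 3, 10)  # (seq_len, batch_size, input_size)

# Unidirectional: output holds the top layer's hidden state at every time step,
# so its last time step coincides with hn[-1].
rnn = nn.LSTM(10, 20, 2)
output, (hn, cn) = rnn(x)
last_step_matches = torch.allclose(output[-1], hn[-1])
print(last_step_matches)

# Bidirectional: the feature dimension of output doubles, and the first
# dimension of hn/cn becomes num_layers * 2.
birnn = nn.LSTM(10, 20, 2, bidirectional=True)
bi_output, (bi_hn, bi_cn) = birnn(x)
print(bi_output.shape, bi_hn.shape, bi_cn.shape)
```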
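import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

rnn = nn.LSTM(input_size=1, hidden_size=20, num_layers=2)
inputs = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]], dtype=torch.float)
lens = [2, 1, 3]  # true (unpadded) length of each sequence; 0 is the padding value
# Build the input with shape torch.Size([3, 3, 1]), i.e. batch_size=3, sequence length=3, input size=1
inputs = inputs.unsqueeze(2)
print(inputs)
# tensor([[[1.],
#          [2.],
#          [0.]],
#         [[3.],
#          [0.],
#          [0.]],
#         [[4.],
#          [5.],
#          [6.]]])
# The first dimension is the batch, so batch_first=True;
# enforce_sorted=False because the lengths are not sorted in descending order
packed_seq = pack_padded_sequence(inputs, lens, batch_first=True, enforce_sorted=False)
# Feed the packed sequence to the LSTM without passing initial hidden and cell states (they default to zeros)
output, (hn, cn) = rnn(packed_seq)
# Apply the inverse operation to unpack the result
output, output_lens = pad_packed_sequence(output, batch_first=True)
# output is the padded LSTM output, output_lens holds the length of each sample in the batch
print(output.shape)  # torch.Size([3, 3, 20])
print(output_lens)   # tensor([2, 1, 3])
print(hn.shape)      # torch.Size([2, 3, 20])
print(cn.shape)      # torch.Size([2, 3, 20])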
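What packing buys you is easiest to see by comparing `hn` with the unpacked output. The sketch below repeats the setup above under different variable names (chosen for this illustration) and checks two things: `hn[-1]` holds each sequence's hidden state at its last valid time step rather than at the padded end, and the padded positions of the unpacked output are filled with zeros.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
rnn = nn.LSTM(input_size=1, hidden_size=20, num_layers=2)
x = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]], dtype=torch.float).unsqueeze(2)
lens = torch.tensor([2, 1, 3])  # lengths must live on the CPU

packed = pack_padded_sequence(x, lens, batch_first=True, enforce_sorted=False)
packed_out, (hn, cn) = rnn(packed)
out, out_lens = pad_packed_sequence(packed_out, batch_first=True)

# Pick, for every sequence, the output at its last valid time step.
batch_idx = torch.arange(out.size(0))
last_valid = out[batch_idx, out_lens - 1]  # (batch_size, hidden_size)

print(torch.allclose(last_valid, hn[-1]))  # top-layer final state per sequence
print(out[1, 1:].abs().sum().item())       # padded steps of the length-1 sequence
```

Indexing `out` with `out_lens - 1` is therefore one way to pull a fixed-size representation per sequence; for a unidirectional LSTM, taking `hn[-1]` should give the same result more directly.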
To be continued...