import torch
import torch.nn as nn


class StackingGRUCell(nn.Module):
    """
    Multi-layer GRU cell
    """
    def __init__(self, input_size, hidden_size, num_layers, dropout):
        super(StackingGRUCell, self).__init__()
        self.num_layers = num_layers
        self.grus = nn.ModuleList()
        self.dropout = nn.Dropout(dropout)

        # The first layer maps input_size -> hidden_size; every later
        # layer maps hidden_size -> hidden_size.
        self.grus.append(nn.GRUCell(input_size, hidden_size))
        for i in range(1, num_layers):
            self.grus.append(nn.GRUCell(hidden_size, hidden_size))
    def forward(self, input, h0):
        """
        Input:
            input (batch, input_size): input tensor
            h0 (num_layers, batch, hidden_size): initial hidden state
        ---
        Output:
            output (batch, hidden_size): the final layer's output tensor
            hn (num_layers, batch, hidden_size): the hidden state of each layer
        """
        hn = []
        output = input
        for i, gru in enumerate(self.grus):
            # At each layer, the current output passes through one GRU cell,
            # which updates that layer's hidden state.
            hn_i = gru(output, h0[i])
            hn.append(hn_i)
            if i != self.num_layers - 1:
                # Between layers (but not after the last one), the output
                # goes through a dropout layer.
                output = self.dropout(hn_i)
            else:
                output = hn_i
        # Stack the list of per-layer hidden states into a single tensor.
        hn = torch.stack(hn)
        return output, hn
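Note that StackingGRUCell consumes one time step per call, whereas nn.GRU consumes a whole sequence. To run the cell over a sequence you unroll it manually; a minimal sketch (the shapes here are illustrative assumptions, not from the paper's code):

cell = StackingGRUCell(input_size=5, hidden_size=10, num_layers=3, dropout=0.1)
seq = torch.randn(4, 7, 5)       # (seq_len, batch, input_size)
h = torch.zeros(3, 7, 10)        # (num_layers, batch, hidden_size)
outputs = []
for t in range(seq.size(0)):
    out_t, h = cell(seq[t], h)   # one step; carry the stacked hidden state forward
    outputs.append(out_t)
outputs = torch.stack(outputs)   # (seq_len, batch, hidden_size)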
In nn.GRU, hn holds each layer's hidden state at the last time step: for an input sequence of length seq_len, hn contains, for every layer, the hidden state after the final of the seq_len steps. In StackingGRUCell, hn is computed by each layer's GRUCell for a single given time step. So if seq_len is 1, the hn returned by nn.GRU and by StackingGRUCell should be identical, and this should hold even more directly for output.
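As a quick sanity check on nn.GRU's semantics (shapes are illustrative assumptions, not from the original code), the last time step of output coincides with the top layer's slice of hn:

gru = nn.GRU(input_size=5, hidden_size=10, num_layers=3)
x = torch.randn(4, 7, 5)           # (seq_len, batch, input_size)
h0 = torch.zeros(3, 7, 10)         # (num_layers, batch, hidden_size)
out, hn = gru(x, h0)
out.shape, hn.shape                # (torch.Size([4, 7, 10]), torch.Size([3, 7, 10]))
torch.allclose(out[-1], hn[-1])    # True: output's last step == top layer's hn

To test the seq_len == 1 case end to end, first wrap a plain GRU with nothing else added: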
class StackingGRU_tst(nn.Module):
    """A bare nn.GRU wrapper for comparison (the test below ends up
    instantiating nn.GRU directly instead)."""
    def __init__(self, input_size, hidden_size, num_layers, dropout):
        super(StackingGRU_tst, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size, num_layers=num_layers,
                          dropout=dropout, batch_first=True)

    def forward(self, input, h0):
        output, hn = self.gru(input, h0)
        return output, hn
input_size = 5
hidden_size = 10
num_layers = 3
dropout = 0.1
batch_size = 7

gru_cell_model = StackingGRUCell(input_size, hidden_size, num_layers, dropout)
gru_cell_model
'''
StackingGRUCell(
  (grus): ModuleList(
    (0): GRUCell(5, 10)
    (1): GRUCell(10, 10)
    (2): GRUCell(10, 10)
  )
  (dropout): Dropout(p=0.1, inplace=False)
)
'''

gru_model = nn.GRU(input_size, hidden_size, num_layers, dropout=dropout)
gru_model
'''
GRU(5, 10, num_layers=3, dropout=0.1)
'''
with torch.no_grad():
    for i in range(num_layers):
        # For each layer, copy the weights and biases from the GRUCell
        # stack into the corresponding nn.GRU parameters.
        getattr(gru_model, 'weight_ih_l' + str(i)).copy_(gru_cell_model.grus[i].weight_ih)
        getattr(gru_model, 'weight_hh_l' + str(i)).copy_(gru_cell_model.grus[i].weight_hh)
        getattr(gru_model, 'bias_ih_l' + str(i)).copy_(gru_cell_model.grus[i].bias_ih)
        getattr(gru_model, 'bias_hh_l' + str(i)).copy_(gru_cell_model.grus[i].bias_hh)

input_data = torch.randn(batch_size, input_size)
h0_cell = torch.randn(num_layers, batch_size, hidden_size)
h0_gru = h0_cell.clone()  # ensure both models start from the same initial state
Because dropout is involved, the same random seed must be set before each forward pass so that both models draw identical dropout masks.
torch.manual_seed(1215)
output_cell, hn_cell = gru_cell_model(input_data, h0_cell)

torch.manual_seed(1215)
# nn.GRU expects (seq_len, batch, input_size), so add a seq_len dimension of 1.
output_gru, hn_gru = gru_model(input_data.unsqueeze(0), h0_gru)

torch.allclose(output_cell, output_gru.squeeze(0)), torch.allclose(hn_cell, hn_gru)

# (True, True)
The results match, so it seems the StackingGRUCell in the paper's code could simply be replaced by a plain nn.GRU?
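One caveat: the comparison above passes in training mode only because both models happen to draw their dropout masks from the global RNG in the same order after reseeding. A less fragile check (a small sketch under the same setup, with the expected result noted as an assumption) is to switch both models to eval() mode, where dropout is inert and no seeding is needed:

gru_cell_model.eval()
gru_model.eval()
with torch.no_grad():
    out_c, hn_c = gru_cell_model(input_data, h0_cell)
    out_g, hn_g = gru_model(input_data.unsqueeze(0), h0_gru)
torch.allclose(out_c, out_g.squeeze(0)), torch.allclose(hn_c, hn_g)
# expected: (True, True), since with dropout disabled both compute the same GRU math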