在pytorch中调用RNN模型的时候,使用self.rnn = nn.RNN(embedding_num,hidden_num)往往忽略了其中的一个参数,在点开RNN源码的时候,可以看到其中batch_first这个参数:
可以看到这个参数如果为True的话,你的输入输出的tensor形状为(batch,maxlen,embedding_num),但是这个参数默认是False的,所以如果忘记了这个参数,要把输入的batch_idx在多加一行transpose的代码,如下:
- batch_text_idx = batch_text_idx.transpose(1,0,2)
- #也就是说将原来的batch 和 maxlen 要对调
如果说直接设置为True的话,就不用去做转置的部分。