本文参考swarnabha/pytorch-text-classification-torchtext-lstm,讲解一个用LSTM训练kaggle数据集的案例。该项目使用torchtext的API构建数据集,而torchtext又使用了spacy库为句子分词。本文会先讲解项目架构,并附上运行代码时可能遇到的踩坑讲解。
该kaggle页面的data标签页里给出了数据集的下载链接,请下载其中的train.csv和test.csv文件。
glove前缀的文件,比如"glove.6B.100d.txt"等,都是torchtext的预训练词库。如果你不下载它们,torchtext也会自动下载到
.vector_cache
目录的。
粗略通读项目架构如下:
第一步,使用scikit-learn的工具方法切分pandas dataframe形式的数据集
# split data into train and validation
train_df, valid_df = train_test_split(train)
print(train_df.head())
print(valid_df.head())
第二步,设置tokenize策略
TEXT = data.Field(tokenize = 'spacy', include_lengths = True)
LABEL = data.LabelField(dtype = torch.float)
利用Field,将Pandas dataframe包装成torchtext dataset
fields = [('text',TEXT), ('label',LABEL)]
train_ds, val_ds = DataFrameDataset.splits(fields, train_df=train_df, val_df=valid_df)
第三步,构建词库,对单词作one-hot编码。
TEXT.build_vocab(train_ds,
max_size = MAX_VOCAB_SIZE,
vectors = 'glove.6B.200d',
unk_init = torch.Tensor.zero_)
LABEL.build_vocab(train_ds)
第四步,切分数据集
train_iterator, valid_iterator = data.BucketIterator.splits(
(train_ds, val_ds),
batch_size = BATCH_SIZE,
sort_within_batch = True,
device = device)
最后,循环读取批次,将embedding送入lstm网络。
for epoch in range(num_epochs):
train_loss, train_acc = train(model, train_iterator)
由于版本兼容性问题,运行代码可能遇到错误AttributeError: module ‘torchtext.data‘ has no attribute ‘Field‘,也可以参考attributeerror-module-torchtext-data-has-no-attribute-field。使用torchtext 0.10(可能会安装旧版的pytorch,所以用conda开个新环境,凑合着用吧),然后from torchtext import data
改成from torchtext.legacy import data
。
阅读torchtext的版本更新与api变迁可以得知APi变迁。
from torchtext import data
。from torchtext.legacy import data
。总之,要复用教程的API,最好用torchtext 0.9或0.10。pytorch-sentiment-analysis列出了一些基于该版本API的教程,可以参考它的第一个教程运行下。
如果运行遇到以下问题,说明需要下载spacy语言模型en_core_web_sm并安装。
Can’t find model ‘en_core_web_sm’. It doesn’t seem to be a Python package or a valid path to a data directory.
这是因为自然语言处理类库spacy需要加载语言模型en_core_web_sm。尽管官方教程 Quickstart会建议你运行以下命令:
python -m spacy download en_core_web_sm
但是因为无法连接外网的缘故,这条命令大概率会连接失败(你可以开启代理再尝试,但笔者并没成功)。
所以我参考NLP Spacy中en_core_web_sm安装问题,及最新版下载地址,做法如下:
pip install
安装这个压缩包。笔者下载的是3.4.1版本。如果下载2.5.2,可能会出现cannot read config之类的错误。笔者不知道怎么解决。
在网络上可以看到不少用RNN搭建的模型:
training_data
变量的数据很匮乏),所以只能个人理解。