码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • RNN/LSTM (二) 实践案例


    文章目录

    • 介绍
    • 主要步骤
    • 踩坑
      • module ‘torchtext.data‘ has no attribute ‘Field
      • 缺少spacy语言模型en_core_web_sm
      • 相关教程
    • 其它RNN案例

    介绍

    本文参考swarnabha/pytorch-text-classification-torchtext-lstm,讲解一个用LSTM训练kaggle数据集的案例。该项目使用torchtext的API构建数据集,而torchtext又使用了spacy库为句子分词。本文会先讲解项目架构,并附上运行代码时可能遇到的踩坑讲解。

    该kaggle页面的data标签页里给出了数据集的下载链接,请下载其中的train.csv和test.csv文件。

    glove前缀的文件,比如"glove.6B.100d.txt"等,都是torchtext的预训练词库。如果你不下载它们,torchtext也会自动下载到.vector_cache目录的。

    主要步骤

    粗略通读项目架构如下:

    第一步,使用scikit-learn的工具方法切分pandas dataframe形式的数据集

    # split data into train and validation
    train_df, valid_df = train_test_split(train)
    print(train_df.head())
    print(valid_df.head())
    
    • 1
    • 2
    • 3
    • 4

    第二步,设置tokenize策略

    TEXT = data.Field(tokenize = 'spacy', include_lengths = True)
    LABEL = data.LabelField(dtype = torch.float)
    
    • 1
    • 2

    利用Field,将Pandas dataframe包装成torchtext dataset

    fields = [('text',TEXT), ('label',LABEL)]
    train_ds, val_ds = DataFrameDataset.splits(fields, train_df=train_df, val_df=valid_df)
    
    • 1
    • 2

    第三步,构建词库,对单词作one-hot编码。

    TEXT.build_vocab(train_ds,
                     max_size = MAX_VOCAB_SIZE,
                     vectors = 'glove.6B.200d',
                     unk_init = torch.Tensor.zero_)
    
    LABEL.build_vocab(train_ds)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    第四步,切分数据集

    train_iterator, valid_iterator = data.BucketIterator.splits(
        (train_ds, val_ds),
        batch_size = BATCH_SIZE,
        sort_within_batch = True,
        device = device)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    最后,循环读取批次,将embedding送入lstm网络。

    for epoch in range(num_epochs):
        train_loss, train_acc = train(model, train_iterator)
    
    • 1
    • 2

    踩坑

    module ‘torchtext.data‘ has no attribute ‘Field

    由于版本兼容性问题,运行代码可能遇到错误AttributeError: module ‘torchtext.data‘ has no attribute ‘Field‘,也可以参考attributeerror-module-torchtext-data-has-no-attribute-field。使用torchtext 0.10(可能会安装旧版的pytorch,所以用conda开个新环境,凑合着用吧),然后from torchtext import data改成from torchtext.legacy import data。

    阅读torchtext的版本更新与api变迁可以得知APi变迁。

    • 在0.8版本以前,是from torchtext import data。
    • 在0.9到0.12版本之间,是from torchtext.legacy import data。
    • 在0.12版本之后,该data库已经删除,如果坚持要用,需要参考新版API教程。

    总之,要复用教程的API,最好用torchtext 0.9或0.10。pytorch-sentiment-analysis列出了一些基于该版本API的教程,可以参考它的第一个教程运行下。

    缺少spacy语言模型en_core_web_sm

    如果运行遇到以下问题,说明需要下载spacy语言模型en_core_web_sm并安装。

    Can’t find model ‘en_core_web_sm’. It doesn’t seem to be a Python package or a valid path to a data directory.

    这是因为自然语言处理类库spacy需要加载语言模型en_core_web_sm。尽管官方教程 Quickstart会建议你运行以下命令:

    python -m spacy download en_core_web_sm
    
    • 1

    但是因为无法连接外网的缘故,这条命令大概率会连接失败(你可以开启代理再尝试,但笔者并没成功)。

    所以我参考NLP Spacy中en_core_web_sm安装问题,及最新版下载地址,做法如下:

    1. 到github的release界面搜索"en_core_web_sm",找最新版的压缩包下载
    2. 用pip install 安装这个压缩包。

    笔者下载的是3.4.1版本。如果下载2.5.2,可能会出现cannot read config之类的错误。笔者不知道怎么解决。

    相关教程

    • 为了更好理解torchtext0.9~0.12版本下的torchtext.legacy.data包,可以学习bentrevett/pytorch-sentiment-analysis的教程1,其中有提到了Field是如何帮助构建vocab的,并在这期间对句子的单词作清洗工作(为了减少要训练的embedding的数,保留最高频率的单词)。
    • 文章使用的是自定义的torchtext Dataset,如果想了解自定义pytorch Dataset的用法,可学习pytorch dataset

    其它RNN案例

    在网络上可以看到不少用RNN搭建的模型:

    • pytorch官方的教程,SEQUENCE MODELS AND LONG SHORT-TERM MEMORY NETWORKS 的模型代码很清晰,架构完整,但是缺乏训练数据集(training_data变量的数据很匮乏),所以只能个人理解。
    • towardsdatascience的教程,lstm-text-classification-using-pytorch,其训练集存在kaggle链接
  • 相关阅读:
    【CIKM 2023】扩散模型加速采样算法OLSS,大幅提升模型推理速度
    区块链在公益活动平台中的应用研究
    【微服务】SpringCloud中Ribbon的WeightedResponseTimeRule策略
    Ubuntu18.04切换Python版本
    iOS图片占内存大小与什么有关?
    基于Java+微信小程序实现《购物商城系统》
    【光学】Matlab实现色散曲线拟合
    【UE5】游戏框架GamePlay
    视频汇聚平台EasyCVR对接GA/T 1400视图库:结构化数据(人员/人脸、车辆、物品)对象XMLSchema描述
    bit和B
  • 原文地址:https://blog.csdn.net/duoyasong5907/article/details/128125570
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号