• NVIDIA 7th SkyHackathon(四)Nemo ASR 模型训练与评估


    1.模型加载

    1.1 导入 NeMo

    import nemo
    import nemo.collection.asr as nemo_asr
    import torch
    
    # 检查 nemo 版本 '1.4.0'
    print(nemo.__version__)
    
    # 检查 torch 版本 '1.12.1+cu113'
    print(torch.__version__) 
    
    # 检查 GPU 是否被 torch 调用 True
    print(torch.cuda.is_available()) 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    1.2 加载自动语音识别模型

    NeMo 的每个集合 ASR、NLP、TTS 中,都提供了许多预训练模型,使用 list_available_models() 可以查看 ASR 所提供的所有预训练模型

    nemo_asr.models.EncDecCTCModel.list_available_models()
    '''
    [PretrainedModelInfo(
      pretrained_model_name=QuartzNet15x5Base-En,
      description=QuartzNet15x5 model trained on six datasets: LibriSpeech, Mozilla Common Voice (validated clips from en_1488h_2019-12-10), WSJ, Fisher, Switchboard, and NSC Singapore English. It was trained with Apex/Amp optimization level O1 for 600 epochs. The model achieves a WER of 3.79% on LibriSpeech dev-clean, and a WER of 10.05% on dev-other. Please visit https://ngc.nvidia.com/catalog/models/nvidia:nemospeechmodels for further details.,
      location=https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/QuartzNet15x5Base-En.nemo
     ),
     ...]
    '''
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    本次比赛 NVIDIA 在相关资料中提供了中文预训练模型 stt_zh_quartznet15x5.nemo,使用 restore_form() 进行加载

    # 加载中文预训练模型并实例化
    quartznet = nemo_asr.models.EncDecCTCModel.restore_from("stt_zh_quartznet15x5.nemo")
    
    • 1
    • 2

    1.3 加载 quartznet 配置文件

    使用 YAML 读取 quartznet 模型配置文件

    try:
        from ruamel.yaml import YAML
    except ModuleNotFoundError:
        from ruamel_yaml import YAML
    config_path ="quartznet_15x5_zh.yaml"
    
    yaml = YAML(typ='safe')
    with open(config_path) as f:
        params = yaml.load(f)
        
    print(params)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    1.4 加载数据清单并传给配置文件

    将 1.2 节所制作的数据清单,传给配置文件

    # 加载数据清单
    train_manifest = "/root/data/train.json"
    test_manifest = "/root/data/val.json"
    
    # 传递给配置文件
    params['model']['train_ds']['manifest_filepath']=train_manifest
    params['model']['validation_ds']['manifest_filepath']=test_manifest
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    2.模型训练

    2.1 训练

    使用迁移学习的方法训练模型

    # 设置训练集
    quartznet.setup_training_data(train_data_config=params['model']['train_ds'])
    # 设置测试集
    quartznet.setup_validation_data(val_data_config=params['model']['validation_ds'])
    
    # 开始训练
    import pytorch_lightning as pl
    trainer = pl.Trainer(gpus=1,max_epochs=200)
    trainer.fit(quartznet)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    2.2 保存并重载

    # 将训练好的模型保存为.nemo格式
    quartznet.save_to("7th_asr_model_1.nemo")
    
    # 重新加载模型
    try_model_1 = nemo_asr.models.EncDecCTCModel.restore_from("7th_asr_model_1.nemo")
    
    • 1
    • 2
    • 3
    • 4
    • 5

    3 模型评估

    from ASR_metrics import utils as metrics
    
    # 加载测试数据
    asr_result = try_model_1.transcribe(paths2audio_files=["/root/data/test/1/1.wav"])
    print(asr_result)
    
    #指定正确答案
    s1 = "请检测出果皮"
    #识别结果
    s2 = " ".join(asr_result)
    
    #计算字错率cer
    print("字错率:{}".format(metrics.calculate_cer(s1,s2)))
    #计算准确率
    print("准确率:{}".format(1-metrics.calculate_cer(s1,s2)))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
  • 相关阅读:
    .env[mode]文件中如何添加注释
    C++ 太卷,转 Java?
    基于LINUX的TCP协WireShark抓包分析
    集成多元算法,打造高效字面文本相似度计算与匹配搜索解决方案,助力文本匹配冷启动[BM25、词向量、SimHash、Tfidf、SequenceMatcher]
    springsecurity 使用浅谈(一)
    智慧城市-疫情流调系列3-Prompt-UIE,生成式通用信息抽取
    11-css3新增选择器
    【性能测试】Cannot assign requested address (Address not available)
    深度学习-学习率调度,正则化,dropout
    java项目-第136期ssm超市收银管理系统-java毕业设计
  • 原文地址:https://blog.csdn.net/u011815404/article/details/128113945