• 【pytorch】深度学习准备:基本配置


    深度学习中常用包

    import os 
    import numpy as np 
    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset, DataLoader
    import torch.optim as optimizer
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    超参数设置
    2种设置方式:将超参数直接设置在训练的代码中;用yaml、json,dict等文件来存储超参数

    # 批次的大小
    batch_size = 16
    # 优化器的学习率
    lr = 1e-4
    # 训练次数
    max_epochs = 100
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    GPU设置

    # 方案一:使用os.environ,这种情况如果使用GPU不需要设置
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' # 指明调用的GPU为0,1号
    
    # 方案二:使用“device”,后续对要使用GPU的变量用.to(device)即可
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") # 指明调用的GPU为1号
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    使用argparse和yaml文件

    1. argparse的使用:
    import argparse
    """
    	argparse.ArgumentParser()创建了一个对象
    	add_argument()添加参数
    	parse_args()将参数封装在opt内,各个参数通过.运算符调用
    """
    
    def main(opt):
        print(opt.num_batches)
    
    if __name__ == '__main__':
    
        parse = argparse.ArgumentParser()
        parse.add_argument('--num_batches', type=int, default=50, help='the num of batch')
        parse.add_argument('--num_window', type=int, default=5, help='the num of window')
        parse.add_argument('--weight', type=str, default= '../pretrain.pth', help='the path of pretrained model')
    
        opt = parse.parse_args()
        main(opt)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    1. yaml文件的使用
      下面是一个yaml文件的例子,参数呈现层级结构
    device: 'cpu'
    
    data:
        train_path: 'data/train'
        test_path: 'test/train'
        num: 1000
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    读取yaml文件

    def read_yaml(path):
    """
    	read()读入yaml文件中的内容
    	safe_load()加载yaml格式的内容并转换为字典
    """
        file = open(path, 'r', encoding='utf-8')
        string = file.read()
        file.close()
        dict = yaml.safe_load(string)
    
        return dict
    
    path = 'config.yaml'
    Dict = read_yaml(path)
    device = Dict['device']
    print(device)
    train_path = Dict['data']['train_path']
    print(train_path)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    1. 使用方法
      在yaml文件中给全部参数设置默认值,使用argparse库设置待调参数的值

    参考资料

    1. 深度学习代码中的argparse以及yaml文件的使用
    2. datawhale的thorough-pytorch repo
  • 相关阅读:
    Java Pattern.compile()具有什么功能呢?
    机器学习---增量学习
    原型设计模式
    32.4.2 安装JDK
    macOS查看切换当前用户和shell
    C++ 虚函数优化探索简介
    代理技术的演进:从SOCKS到透明代理再到智能HTTP代理
    vue+element实现多级表头加树结构
    HTTP协议解析
    elementUI el-table+树形结构子节点选中后没有打勾?(element版本问题 已解决)
  • 原文地址:https://blog.csdn.net/m0_61819793/article/details/133746850