• 使用argparse进行调参


    argparse是深度学习项目调参时常用的python标准库,使用argparse后,我们在命令行输入的参数就可以以这种形式python filename.py --lr 1e-4 --batch_size 32来完成对常见超参数的设置。,一般使用时可以归纳为以下三个步骤

    使用步骤:#

    • 创建ArgumentParser()对象
    • 调用add_argument()方法添加参数
    • 使用parse_args()解析参数 在接下来的内容中,我们将以实际操作来学习argparse的使用方法
    import argparse
    
    parser = argparse.ArgumentParser() # 创建一个解析对象
    
    parser.add_argument() # 向该对象中添加你要关注的命令行参数和选项
    
    args = parser.parse_args() # 调用parse_args()方法进行解析
    

    常见规则#

    • 在命令行中输入python demo.py -h或者python demo.py --help可以查看该python文件参数说明
    • arg字典类似python字典,比如arg字典Namespace(integers='5')可使用arg.参数名来提取这个参数
    • parser.add_argument('integers', type=str, nargs='+',help='传入的数字') nargs是用来说明传入的参数个数,'+' 表示传入至少一个参数,'*' 表示参数可设置零个或多个,'?' 表示参数可设置零个或一个
    • parser.add_argument('-n', '--name', type=str, required=True, default='', help='名') required=True表示必须参数, -n表示可以使用短选项使用该参数
    • parser.add_argument("--test_action", default='False', action='store_true')store_true 触发时为真,不触发则为假(test.py,输出为 Falsetest.py --test_action,输出为 True

    使用config文件传入超参数#

    为了使代码更加简洁和模块化,可以将有关超参数的操作写在config.py,然后在train.py或者其他文件导入就可以。具体的config.py可以参考如下内容。

    import argparse  
      
    def get_options(parser=argparse.ArgumentParser()):  
      
        parser.add_argument('--workers', type=int, default=0,  
                            help='number of data loading workers, you had better put it '  
                                  '4 times of your gpu')  
      
        parser.add_argument('--batch_size', type=int, default=4, help='input batch size, default=64')  
      
        parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for, default=10')  
      
        parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=1e-3')  
      
        parser.add_argument('--seed', type=int, default=118, help="random seed")  
      
        parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')  
        parser.add_argument('--checkpoint_path',type=str,default='',  
                            help='Path to load a previous trained model if not empty (default empty)')  
        parser.add_argument('--output',action='store_true',default=True,help="shows output")  
      
        opt = parser.parse_args()  
      
        if opt.output:  
            print(f'num_workers: {opt.workers}')  
            print(f'batch_size: {opt.batch_size}')  
            print(f'epochs (niters) : {opt.niter}')  
            print(f'learning rate : {opt.lr}')  
            print(f'manual_seed: {opt.seed}')  
            print(f'cuda enable: {opt.cuda}')  
            print(f'checkpoint_path: {opt.checkpoint_path}')  
      
        return opt  
      
    if __name__ == '__main__':  
        opt = get_options()
    
    $ python config.py
    
    num_workers: 0
    batch_size: 4
    epochs (niters) : 10
    learning rate : 3e-05
    manual_seed: 118
    cuda enable: True
    checkpoint_path:
    

    随后在train.py等其他文件,我们就可以使用下面的这样的结构来调用参数。

    # 导入必要库
    ...
    import config
    
    opt = config.get_options()
    
    manual_seed = opt.seed
    num_workers = opt.workers
    batch_size = opt.batch_size
    lr = opt.lr
    niters = opt.niters
    checkpoint_path = opt.checkpoint_path
    
    # 随机数的设置,保证复现结果
    def set_seed(seed):
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    
    ...
    
    
    if __name__ == '__main__':
      set_seed(manual_seed)
      for epoch in range(niters):
        train(model,lr,batch_size,num_workers,checkpoint_path)
        val(model,lr,batch_size,num_workers,checkpoint_path)
    
    

    参考:

    https://zhuanlan.zhihu.com/p/56922793

    (14条消息) python argparse中action的可选参数store_true的作用_元气少女wuqh的博客-CSDN博客

    [6.6 使用argparse进行调参 — 深入浅出PyTorch (datawhalechina.github.io)](https://datawhalechina.github.io/thorough-pytorch/第六章/6.6 使用argparse进行调参.html)

  • 相关阅读:
    2022-6-2
    Vue3 - 响应式工具函数(使用教程)
    Java-基于SSM+JSP的二手手机回收管理系统
    HTML的基础标签和HTML的Form表单
    年薪中位数超30万,南大AI专业首届毕业生薪资曝光
    Verilog 实现CDC中单bit 跨时钟域,从慢时钟域到快时钟域
    Android Studio真机运行时提示“安装失败”
    HTML模板 宽屏大气的企业官网网站模板
    SpringBoot监控原理、actuator、设置端点
    基于Python实现的图的同构算法
  • 原文地址:https://www.cnblogs.com/qftie/p/16319150.html