• 深度学习训练时冻结部分参数的方法


    问题描述

    在使用mmdetection代码库时需要冻结部分网络参数,只训练一部分的网络。这里提供一种简单且不容易出现bug的方法,不仅仅适用于mmdetection代码库,也可以使用在其他的代码库里面,不过需要一定的改动。

    解决方案

    mmdetection这个库里面优化器的初始化的位置在:mmdet/apis/train.py,如果使用的是其他代码库的话,需要找到对应的optimizer的初始化的位置。在mmdetection中,原始的optimizer的定义为:

    # build optimizer
    auto_scale_lr(cfg, distributed, logger)
    optimizer = build_optimizer(model, cfg.optimizer)
    
    • 1
    • 2
    • 3

    将这里原本optimizer定义的方式去掉(注释掉),改为新的定义方式。

    首先,确定自己需要使用的优化器,一般从config文件里就可以确认。比如mmdetection里面,configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py 中就有:

    optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001)
    
    • 1

    我们根据config文件确认自己需要使用的是SGD优化器之后,直接通过pytorch生成优化器。通过这种方式定义优化器时,需要指定要优化的参数。这样,我们就可以冻结部分参数,对另一部分参数进行优化。具体来说,可以先打印模型中的参数,确认自己要优化的参数的name之后,将其加入到待优化的参数列表里面,如下面的代码所示。

    parameters = []
    for name, p in model.named_parameters():
        print(name)
        if "retina_cls" in name:
            parameters.append(p)
    
    optimizer = torch.optim.SGD(
        parameters, lr=cfg.optimizer['lr'], momentum=cfg.optimizer['momentum'],
        weight_decay=cfg.optimizer['weight_decay']
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    完成上面的操作之后,除了parameters里面的参数其他都不会被优化、不会变了。有的同学可能不太放心,那么在实际的model里可以打印参数的具体数值、回传梯度、requires_grad等信息来确认。比如,我使用的是faster RCNN网络,找到forward_train函数。faster RCNN的forward_train函数是定义在其父类TwoStageDetector里(位置在mmdet/models/detectors/two_stage.py)。在loss回传前加入如下代码,打印参数的name、requires_grad以及具体参数数值。

    for name, p in self.roi_head.named_parameters():
        print(name)
        print(p.requires_grad)
        print(p)
     
    losses.update(roi_losses)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    打印之后可以发现,虽然所有的参数requires_grad都为True,都具有回传梯度,但是除了我们在parameters里面指定的参数,其余的都没有更新(反复运行几次看数值有没有变化即可)。值得注意的是,我们不需要担心其他参数会随着梯度而更新,所以额外再设定其他的参数requires_grad为false了,这样会导致程序出现bug。实际上,其余的参数虽然有回传梯度,但是因为不在optimizer优化范围内,所以不会更新

  • 相关阅读:
    无涯教程-JavaScript - ASIN函数
    生命在于学习——Cobalt Strike体验
    Streptavidin-MAL,Maleimide 马来酰亚胺修饰/标记/偶联链霉亲和素
    加工制造业智慧采购系统解决方案:助力企业实现全流程采购一体化协同
    誉天在线项目~ElementPlus Tag标签用法
    java-net-php-python-s2s酒店管理系统计算机毕业设计程序
    Prometheus集成consul[被监控对象开启basic认证]
    el-tree 懒加载数据,展开的节点与查询条件联动
    《学术小白学习之路13》基于DTM和主题共现网络——实现主题时序演化网络分析(数据代码在结尾)
    启发式搜索: A*算法
  • 原文地址:https://blog.csdn.net/shaojie_45/article/details/127651382