在使用mmdetection代码库时需要冻结部分网络参数,只训练一部分的网络。这里提供一种简单且不容易出现bug的方法,不仅仅适用于mmdetection代码库,也可以使用在其他的代码库里面,不过需要一定的改动。
mmdetection这个库里面优化器的初始化的位置在:mmdet/apis/train.py,如果使用的是其他代码库的话,需要找到对应的optimizer的初始化的位置。在mmdetection中,原始的optimizer的定义为:
# build optimizer
auto_scale_lr(cfg, distributed, logger)
optimizer = build_optimizer(model, cfg.optimizer)
将这里原本optimizer定义的方式去掉(注释掉),改为新的定义方式。
首先,确定自己需要使用的优化器,一般从config文件里就可以确认。比如mmdetection里面,configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py 中就有:
optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001)
我们根据config文件确认自己需要使用的是SGD优化器之后,直接通过pytorch生成优化器。通过这种方式定义优化器时,需要指定要优化的参数。这样,我们就可以冻结部分参数,对另一部分参数进行优化。具体来说,可以先打印模型中的参数,确认自己要优化的参数的name之后,将其加入到待优化的参数列表里面,如下面的代码所示。
parameters = []
for name, p in model.named_parameters():
print(name)
if "retina_cls" in name:
parameters.append(p)
optimizer = torch.optim.SGD(
parameters, lr=cfg.optimizer['lr'], momentum=cfg.optimizer['momentum'],
weight_decay=cfg.optimizer['weight_decay']
)
完成上面的操作之后,除了parameters里面的参数其他都不会被优化、不会变了。有的同学可能不太放心,那么在实际的model里可以打印参数的具体数值、回传梯度、requires_grad等信息来确认。比如,我使用的是faster RCNN网络,找到forward_train函数。faster RCNN的forward_train函数是定义在其父类TwoStageDetector里(位置在mmdet/models/detectors/two_stage.py)。在loss回传前加入如下代码,打印参数的name、requires_grad以及具体参数数值。
for name, p in self.roi_head.named_parameters():
print(name)
print(p.requires_grad)
print(p)
losses.update(roi_losses)
打印之后可以发现,虽然所有的参数requires_grad都为True,都具有回传梯度,但是除了我们在parameters里面指定的参数,其余的都没有更新(反复运行几次看数值有没有变化即可)。值得注意的是,我们不需要担心其他参数会随着梯度而更新,所以额外再设定其他的参数requires_grad为false了,这样会导致程序出现bug。实际上,其余的参数虽然有回传梯度,但是因为不在optimizer优化范围内,所以不会更新。