• 【目标检测】YOLOv5模型从大变小,发生了什么?


    前言

    在某次使用YOLOv5进行实验时,看到模型已经收敛得差不多,于是想提前停止训练,就果断直接终止程序。然而在查看文件大小时,突然发现,正常训练的yolov5m模型大小为40M左右,而此时生成的yolov5m模型大小达到了160M,于是产生如题疑问:模型从大变小,发生了什么?

    问题根源

    回到train.py这个文件,发现在模型训练完成之后,还存在这样一段代码:

    if rank in [-1, 0]:
         # Plots
         if plots:
             plot_results(save_dir=save_dir)  # save as results.png
    
         # Test best.pt
         logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
         if opt.data.endswith('coco.yaml') and nc == 80:  # if COCO
             for m in (last, best) if best.exists() else (last):  # speed, mAP tests
                 results, _, _ = test.test(opt.data,
                                           batch_size=batch_size * 2,
                                           imgsz=imgsz_test,
                                           conf_thres=0.001,
                                           iou_thres=0.7,
                                           model=attempt_load(m, device).half(),
                                           single_cls=opt.single_cls,
                                           dataloader=testloader,
                                           save_dir=save_dir,
                                           save_json=True,
                                           plots=False,
                                           is_coco=is_coco)
    
         # Strip optimizers
         final = best if best.exists() else last  # final model
         for f in last, best:
             if f.exists():
                 strip_optimizer(f)  # strip optimizers
         if opt.bucket:
             os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload
     else:
         dist.destroy_process_group()
     torch.cuda.empty_cache()
     return results
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33

    和模型大小直接挂钩的是这一句:

    strip_optimizer(f)  # strip optimizers
    
    • 1

    这个方法定义在/utils/general.py文件中:

    def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
        # Strip optimizer from 'f' to finalize training, optionally save as 's'
        x = torch.load(f, map_location=torch.device('cpu'))
        if x.get('ema'):
            x['model'] = x['ema']  # replace model with ema
        for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
            x[k] = None
        x['epoch'] = -1
        x['model'].half()  # to FP16
        for p in x['model'].parameters():
            p.requires_grad = False
        torch.save(x, s or f)
        mb = os.path.getsize(s or f) / 1E6  # filesize
        print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    阅读代码,不难发现,这一步,程序将模型文件中的'optimizer', 'training_results', 'wandb_id', 'ema', 'updates'这几个设为None,也就是去除这几个值,同时将模型从FP32转成FP16。

    因此,早停的模型没有经过这个步骤,导致模型精度是FP32,同时包含了大量优化器信息,导致模型过于庞大。

    实验验证

    为了验证答案的正确性,重新来加载模型看看。

    首先加载官方提供的yolov5m.pt模型

    import torch
    
    if __name__ == '__main__':
        ckpt = torch.load('yolov5m.pt')
        print(ckpt)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    输出:

    {'epoch': -1,
    'best_fitness': array([0.45065]), 
    'training_results': None, 
    'model': Model(...)
    'optimizer': None,
    'wandb_id': None
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    可以看到,这个模型文件中,只有best_fitness以及model的结构和参数为有效信息,不包含优化器信息。

    再加载160M的模型:

    import torch
    
    if __name__ == '__main__':
        ckpt = torch.load(r'runs\train\exp\weights\last.pt')
        print(ckpt)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    发现这里输出了大量内容,主要内容是training_results和optimizer,由此可见结论正确。

    {'epoch': 0, 
    'best_fitness': 0.0, 
    'training_results':'....'
    'model': Model(...)
    'updates': 4
    'optimizer': {'state':...}
    ...
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    同时,也可以发现,模型文件实际上是一个字典,例如,可以用下面的方式获取某层结构或参数信息:

    print(ckpt.model[0].conv.conv)  # 打印某层
    print(ckpt.model[0].conv.conv.state_dict())  # 打印该层参数信息
    
    • 1
    • 2

    模型加载解读

    阅读代码,发现官方在加载模型时,并没有直接torch.load,而是单独写了一个attempt_load函数

    def attempt_load(weights, map_location=None):
        # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
        model = Ensemble()
        for w in weights if isinstance(weights, list) else [weights]:
            attempt_download(w)
            ckpt = torch.load(w, map_location=map_location)  # load
            model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model
    
        # 适配pytorch不同版本
        for m in model.modules():
            if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
                m.inplace = True  # pytorch 1.7.0 compatibility
            elif type(m) is Conv:
                m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
    
    	# 如果一个模型就直接返回
        if len(model) == 1:
            return model[-1]  # return model
        else:
            print('Ensemble created with %s\n' % weights)
            for k in ['names', 'stride']:
                setattr(model, k, getattr(model[-1], k))
            return model  # return ensemble
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23

    注意到模型加载完成之后,还有.float().fuse().eval()这样一个操作。

    这三个函数功能如下:

    • float():FP16转换成FP32
    • fuse():将conv和bn层合并,提速模型推理速度
    • eval():eval()是模型进行预测推理时关闭BN(预测数据均值方差计算)和Dropout,从而让结果稳定
      训练过程中,BN会不断计算均值和方差,Dropout比例会使一部分的网络连接不进行计算
      预测过程中,需要让均值和方差稳定不变化,同时会使所有网络连接参与计算
  • 相关阅读:
    GFS 分布式文件系统
    NLP课程笔记-基于transformers的自然语言处理入门
    Mac版eclipse如何安装,运行bpmn文件
    11-包装类
    【Jenkins高级操作,欢迎阅读】
    mysql慢查询日志
    基于java+SpringBoot+VUE+Mysql社区家庭医生服务系统
    性能测试_Day_10(负载测试-获得最大可接受用户并发数)
    数据结构 | 算法的时间复杂度与空间复杂度【通俗易懂】
    【系统和网络软件】上海道宁为您带来适用于Windows的系统和网络软件——MobaXterm与MobaSSH教程
  • 原文地址:https://blog.csdn.net/qq1198768105/article/details/127734182