• yolov7改进优化之蒸馏(二)


    yolov7改进优化之蒸馏(一)-CSDN博客
    上一篇已经基本写出来yolov7/v5蒸馏的整个过程,不过要真的训起来我们还需要进行一些修改。

    Model修改

    蒸馏需要对teacher和student网络的特征层进行loss计算,因此我们forward时要能够返回需要的中间层,这需要修改yolo.py中的Model类。

    forward_once接口修改

    增加接口参数 extra_features用于指定要返回的中间层的索引:

    def forward_once(self, x, profile=False, extra_features: list = []):
    	y, dt = [], []  # outputs
    	features = []
    	for i, m in enumerate(self.model):
    		if m.f != -1:  # if not from previous layer
    			x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
    
    		if not hasattr(self, "traced"):
    			self.traced = False
    
    		if self.traced:
    			if (
    				isinstance(m, Detect)
    				or isinstance(m, IDetect)
    				or isinstance(m, IAuxDetect)
    				or isinstance(m, IKeypoint)
    			):
    				break
    
    		if profile:
    			c = isinstance(m, (Detect, IDetect, IAuxDetect, IBin))
    			o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1e9 * 2 if thop else 0  # FLOPS
    			for _ in range(10):
    				m(x.copy() if c else x)
    			t = time_synchronized()
    			for _ in range(10):
    				m(x.copy() if c else x)
    			dt.append((time_synchronized() - t) * 100)
    			print("%10.1f%10.0f%10.1fms %-40s" % (o, m.np, dt[-1], m.type))
    
    		x = m(x)  # run
    
    		y.append(x if m.i in self.save else None)  # save output
    
    		if i in extra_features:
    			features.append(x)
    		if not self.training and len(extra_features) != 0 and len(extra_features) == len(features):
    			return x, features
    
    	if profile:
    		print("%.1fms total" % sum(dt))
    	if len(extra_features) != 0:
    		return x, features
    	if self.training and isinstance(x, tuple):
    		x = x[-1]
    	return x
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46

    主要增加将中间层返回的代码。

    forward接口修改

    forward接口调用了forward_once接口,因此,forward接口也需要增加这个参数。

    def forward(self, x, augment=False, profile=False, extra_features: list = []):
    	if augment:
    		img_size = x.shape[-2:]  # height, width
    		s = [1, 0.83, 0.67]  # scales
    		f = [None, 3, None]  # flips (2-ud, 3-lr)
    		y = []  # outputs
    		for si, fi in zip(s, f):
    			xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
    			yi = self.forward_once(xi)[0]  # forward
    			# cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
    			yi[..., :4] /= si  # de-scale
    			if fi == 2:
    				yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip ud
    			elif fi == 3:
    				yi[..., 0] = img_size[1] - yi[..., 0]  # de-flip lr
    			y.append(yi)
    		return torch.cat(y, 1), None  # augmented inference, train
    	else:
    		return self.forward_once(x, profile, extra_features)  # single-scale inference, train
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    hyp文件修改

    在hyp文件中添加student_kd_layers和teacher_kd_layers来指定要蒸馏的层,我们可以指定IDetect前面的三个特征层:

    student_kd_layers: [75,88,101]
    teacher_kd_layers: [75,88,101]
    
    • 1
    • 2

    训练

    训练方式与正常训练一样,只是启动时要指定teacher-weights。

    结语

    这一篇结合上一篇就可以吧基于FGD算法的蒸馏训练起来了,其他蒸馏的修改也大同小异了。
    f77d79a3b79d6d9849231e64c8e1cdfa~tplv-dy-resize-origshort-autoq-75_330.jpeg

  • 相关阅读:
    万兆以太网MAC设计(7)ARP协议报文格式详解以及ARP层模块设计
    Mysql中获取所有表名以及表名带时间字符串使用BetweenAnd筛选区间范围
    kubernetes实战入门
    2022-08-24 AndroidR 实现长按按键打开一个app或者打开app的某个界面
    阿里云OSS上传文件超时 探测工具排查方法
    zemax---单透镜设计实例01
    Python编程语言学习:shap.force_plot函数的源码解读之详细攻略
    MySQL单列索引和联合索引
    Docker入门尝鲜
    Python import module package 相关
  • 原文地址:https://blog.csdn.net/liuhao3285/article/details/134000301