• YOLOv7 Improvements: Distillation (Part 1)


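    I have been busy lately and have not posted for a while. I have been using YOLOv7 a lot recently, so here is a summary. In the previous post, "YOLOv5 and YOLOv7 in Practice: Pruning" (CodingInCV's blog on CSDN), we pruned the model to make it faster with little loss of accuracy. That post was about speed; this one improves YOLOv7 from the angle of detection performance (YOLOv5 is similar).
    Besides working on the model itself (adding data, changing the architecture, changing the loss), deep learning offers another way to improve a detector: distillation. Simply put, distillation lets a stronger model (the teacher, with more parameters) guide a weaker student model and thereby improve the student's performance.
    There are many ways to distill. A simple, brute-force one is to make the student directly fit the teacher's output feature maps. Distillation is not a cure-all, though: the student has fewer parameters than the teacher, so it may not absorb the teacher's knowledge well, and whether it helps on your own task has to be tried.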
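    The "fit the teacher's feature map" idea can be sketched in a few lines of PyTorch. This is only an illustration of plain feature imitation, not the FGD loss used later in this post; the class name `FeatureImitationLoss` and the tensor shapes are assumptions made for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureImitationLoss(nn.Module):
    """Hypothetical helper: MSE between student and teacher feature maps."""

    def __init__(self, student_channels: int, teacher_channels: int):
        super().__init__()
        # A 1x1 conv maps the student's channels onto the teacher's channel count.
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)
        else:
            self.align = nn.Identity()

    def forward(self, feat_s: torch.Tensor, feat_t: torch.Tensor) -> torch.Tensor:
        # The teacher is a fixed target, so its features are detached.
        return F.mse_loss(self.align(feat_s), feat_t.detach())


torch.manual_seed(0)
feat_s = torch.randn(2, 64, 20, 20, requires_grad=True)  # student feature map
feat_t = torch.randn(2, 128, 20, 20)                     # teacher feature map
imitation = FeatureImitationLoss(64, 128)
kd_loss = imitation(feat_s, feat_t)
kd_loss.backward()  # gradients reach the student features and the align conv only
```

    In a real training loop this term would be added, with some weight, to the ordinary detection loss.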
    The method chosen here is a distillation algorithm for object detection from CVPR 2022:
    yzd-v/FGD: Focal and Global Knowledge Distillation for Detectors (CVPR 2022) (github.com)
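    For a walkthrough of the method, see: "FGD (CVPR 2022): Focal and Global Distillation for Object Detection" on Zhihu (zhihu.com)
    This post leaves the theory aside for now and focuses on integrating the method into YOLOv7 training. The steps are as follows.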

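    Loading the teacher model

    Distillation first needs a teacher model. The teacher usually has the same kind of architecture as the student, only with more parameters and more layers. For YOLOv5, for example, you can try using yolov5m as the teacher to distill yolov5s.
    Add a command-line argument to train.py: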

        parser.add_argument("--teacher-weights", type=str, default="", help="teacher weights path")
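
    As a quick check of how the new flag behaves, the following standalone sketch parses a made-up command line. The file name `teacher.pt` and the `--weights` default are placeholders for illustration, not values from the original training script.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--weights", type=str, default="yolov7.pt", help="initial (student) weights path")
parser.add_argument("--teacher-weights", type=str, default="", help="teacher weights path")

# Hypothetical command line; "teacher.pt" is a placeholder file name.
opt = parser.parse_args(["--teacher-weights", "teacher.pt"])
# argparse turns the dash in the flag into an underscore in the attribute name.
use_distillation = bool(opt.teacher_weights)

# Without the flag the default is an empty string, i.e. plain training.
opt_plain = parser.parse_args([])
```

    Because the default is an empty string, `if opt.teacher_weights:` is enough to switch distillation on or off.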
    

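    In the train function, load the teacher weights. The procedure mirrors the existing weight loading. Note that in DP or DDP mode the teacher model needs the same handling as the student.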

    # teacher model
    if opt.teacher_weights:
        teacher_weights = opt.teacher_weights
        # with torch_distributed_zero_first(rank):
        #     teacher_weights = attempt_download(teacher_weights)  # download if not found locally
        ckpt = torch.load(teacher_weights, map_location=device)  # load checkpoint
        # stock YOLOv7 Model() expects a cfg (yaml path or dict), so the teacher is built
        # from the architecture stored in the checkpoint rather than from the .pt path
        teacher_model = Model(ckpt["model"].yaml, ch=3, nc=nc).to(device)  # create
        state_dict = ckpt["model"].float().state_dict()  # to FP32
        teacher_model.load_state_dict(state_dict, strict=True)  # load
        teacher_model.eval()  # the teacher is only used for inference
        # set IDetect to train mode
        # teacher_model.model[-1].train()
        logger.info(f"Load teacher model from {teacher_weights}")  # report

    # DP mode
    if cuda and rank == -1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
        if opt.teacher_weights:
            teacher_model = torch.nn.DataParallel(teacher_model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        logger.info("Using SyncBatchNorm()")
        if opt.teacher_weights:
            teacher_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(teacher_model).to(device)
    

    The teacher model does not need gradients, so:

    if opt.teacher_weights:
        for param in teacher_model.parameters():
            param.requires_grad = False
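
    The effect of freezing can be checked on a toy pair of networks. The two small `nn.Sequential` models below are stand-ins chosen for this sketch, not YOLOv7; the point is only that after a backward pass the student has gradients and the teacher has none.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins for the real networks (assumption: any nn.Module behaves the same way).
teacher = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
student = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())

teacher.eval()                   # freeze BatchNorm statistics
for p in teacher.parameters():
    p.requires_grad = False      # no gradients are stored for the teacher

x = torch.randn(4, 3, 16, 16)
with torch.no_grad():            # also skips building the autograd graph
    feat_t = teacher(x)
feat_s = student(x)

loss = F.mse_loss(feat_s, feat_t)
loss.backward()

teacher_has_grad = any(p.grad is not None for p in teacher.parameters())
student_has_grad = all(p.grad is not None for p in student.parameters())
```

    Wrapping the teacher's forward pass in `torch.no_grad()` is optional once the parameters are frozen, but it makes the intent explicit.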
    

    Distillation loss

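    The distillation loss measures how similar one or more teacher layers are to the corresponding student layers, and supervises the student to move toward the teacher. For YOLOv7, the three feature levels can be supervised.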
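    One ingredient of the FGD loss used in this post is a temperature-scaled softmax attention map, which decides how strongly each position counts in the comparison. The sketch below reproduces only that spatial weighting in plain Python; the function name and the 3x3 input are assumptions for illustration, and the input is taken to be the channel-wise mean of the absolute feature values.

```python
import math


def spatial_attention(abs_mean_map, temp=0.5):
    """abs_mean_map: H x W nested list with the channel-wise mean of |feature|."""
    h, w = len(abs_mean_map), len(abs_mean_map[0])
    flat = [v / temp for row in abs_mean_map for v in row]
    m = max(flat)                                # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in flat]
    total = sum(exps)
    scaled = [h * w * e / total for e in exps]   # H*W * softmax, so the mean weight is 1
    return [scaled[r * w:(r + 1) * w] for r in range(h)]


fmap = [[0.1, 0.1, 0.1],
        [0.1, 2.0, 0.1],
        [0.1, 0.1, 0.1]]
sharp = spatial_attention(fmap, temp=0.5)    # low temperature: weight concentrates on the peak
smooth = spatial_attention(fmap, temp=5.0)   # high temperature: weights move toward uniform
```

    A lower temperature makes the loss focus on the most active positions, while a higher one spreads the weight more evenly.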
    Following FGD's open-source code, we add a FeatureLoss class to loss.py, using the default parameters for now:

    # loss.py already imports torch, nn, F and xywh2xyxy. The two init helpers are
    # assumed to come from mmcv, as in the FGD repo.
    from mmcv.cnn import constant_init, kaiming_init

    class FeatureLoss(nn.Module):

        """PyTorch version of `Focal and Global Knowledge Distillation for Detectors` (FGD)

        Args:
            student_channels(int): Number of channels in the student's feature map.
            teacher_channels(int): Number of channels in the teacher's feature map.
            temp (float, optional): Temperature coefficient. Defaults to 0.5.
            alpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001
            beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005
            gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.001
            lambda_fgd (float, optional): Weight of relation_loss. Defaults to 0.000005
        """
        def __init__(self,
                     student_channels,
                     teacher_channels,
                     temp=0.5,
                     alpha_fgd=0.001,
                     beta_fgd=0.0005,
                     gamma_fgd=0.001,
                     lambda_fgd=0.000005,
                     ):
            super(FeatureLoss, self).__init__()
            self.temp = temp
            self.alpha_fgd = alpha_fgd
            self.beta_fgd = beta_fgd
            self.gamma_fgd = gamma_fgd
            self.lambda_fgd = lambda_fgd
        
            if student_channels != teacher_channels:
                self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)
            else:
                self.align = None
            
            self.conv_mask_s = nn.Conv2d(teacher_channels, 1, kernel_size=1)
            self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1)
            self.channel_add_conv_s = nn.Sequential(
                nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
                nn.LayerNorm([teacher_channels//2, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))
            self.channel_add_conv_t = nn.Sequential(
                nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
                nn.LayerNorm([teacher_channels//2, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))
    
            self.reset_parameters()
    
        def forward(self,
                    preds_S,
                    preds_T,
                    gt_bboxes,
                    img_metas):
            """Forward function.
            Args:
                preds_S(Tensor): Bs*C*H*W, student's feature map
                preds_T(Tensor): Bs*C*H*W, teacher's feature map
                gt_bboxes(Tensor): nt*6 YOLO targets, one row per box:
                    (image_idx, class, x_center, y_center, w, h), box normalized to [0, 1]
                img_metas (list): [img_h, img_w] of the network input
            """
            assert preds_S.shape[-2:] == preds_T.shape[-2:], 'the output dim of teacher and student differ'
            device = gt_bboxes.device
            self.to(device)
            if self.align is not None:
                preds_S = self.align(preds_S)
    
            N,C,H,W = preds_S.shape
    
            S_attention_t, C_attention_t = self.get_attention(preds_T, self.temp)
            S_attention_s, C_attention_s = self.get_attention(preds_S, self.temp)
            
            Mask_fg = torch.zeros_like(S_attention_t)
            # Mask_bg = torch.ones_like(S_attention_t)
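            wmin,wmax,hmin,hmax = [],[],[],[]
            img_h, img_w = img_metas
            # normalized xywh -> normalized xyxy
            bboxes = xywh2xyxy(gt_bboxes[:, 2:6])
            # scale the boxes to feature-map cells
            new_boxxes = torch.ones_like(bboxes)
            new_boxxes[:, 0] = torch.floor(bboxes[:, 0]*W)
            new_boxxes[:, 2] = torch.ceil(bboxes[:, 2]*W)
            new_boxxes[:, 1] = torch.floor(bboxes[:, 1]*H)
            new_boxxes[:, 3] = torch.ceil(bboxes[:, 3]*H)
            # clamp to the feature map so a box touching the border cannot yield a negative index
            new_boxxes[:, 0::2].clamp_(0, W - 1)
            new_boxxes[:, 1::2].clamp_(0, H - 1)

            #to int
            new_boxxes = new_boxxes.int()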
    
            for i in range(N):
                new_boxxes_i = new_boxxes[torch.where(gt_bboxes[:,0]==i)]
    
                wmin.append(new_boxxes_i[:, 0])
                wmax.append(new_boxxes_i[:, 2])
                hmin.append(new_boxxes_i[:, 1])
                hmax.append(new_boxxes_i[:, 3])
    
                area = 1.0/(hmax[i].view(1,-1)+1-hmin[i].view(1,-1))/(wmax[i].view(1,-1)+1-wmin[i].view(1,-1))
    
                for j in range(len(new_boxxes_i)):
                    Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1] = \
                            torch.maximum(Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1], area[0][j])
    
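            # tensor arguments keep this working on PyTorch versions without scalar support in torch.where
            Mask_bg = torch.where(Mask_fg > 0, torch.zeros_like(Mask_fg), torch.ones_like(Mask_fg))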
            Mask_bg_sum = torch.sum(Mask_bg, dim=(1,2))
            Mask_bg[Mask_bg_sum>0] /= Mask_bg_sum[Mask_bg_sum>0].unsqueeze(1).unsqueeze(2)
    
            fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg, 
                            C_attention_s, C_attention_t, S_attention_s, S_attention_t)
            mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, S_attention_s, S_attention_t)
            rela_loss = self.get_rela_loss(preds_S, preds_T)
    
            loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \
                + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss
                
            return loss, loss.detach()
    
        def get_attention(self, preds, temp):
            """ preds: Bs*C*H*W """
            N, C, H, W= preds.shape
    
            value = torch.abs(preds)
            # Bs*H*W
            fea_map = value.mean(axis=1, keepdim=True)
            S_attention = (H * W * F.softmax((fea_map/temp).view(N,-1), dim=1)).view(N, H, W)
    
            # Bs*C
            channel_map = value.mean(axis=2,keepdim=False).mean(axis=2,keepdim=False)
            C_attention = C * F.softmax(channel_map/temp, dim=1)
    
            return S_attention, C_attention
    
    
        def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t):
            loss_mse = nn.MSELoss(reduction='sum')
            
            Mask_fg = Mask_fg.unsqueeze(dim=1)
            Mask_bg = Mask_bg.unsqueeze(dim=1)
    
            C_t = C_t.unsqueeze(dim=-1)
            C_t = C_t.unsqueeze(dim=-1)
    
            S_t = S_t.unsqueeze(dim=1)
    
            fea_t= torch.mul(preds_T, torch.sqrt(S_t))
            fea_t = torch.mul(fea_t, torch.sqrt(C_t))
            fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))
            bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))
    
            fea_s = torch.mul(preds_S, torch.sqrt(S_t))
            fea_s = torch.mul(fea_s, torch.sqrt(C_t))
            fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))
            bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))
    
            fg_loss = loss_mse(fg_fea_s, fg_fea_t)/len(Mask_fg)
            bg_loss = loss_mse(bg_fea_s, bg_fea_t)/len(Mask_bg)
    
            return fg_loss, bg_loss
    
    
        def get_mask_loss(self, C_s, C_t, S_s, S_t):
    
            mask_loss = torch.sum(torch.abs((C_s-C_t)))/len(C_s) + torch.sum(torch.abs((S_s-S_t)))/len(S_s)
    
            return mask_loss
         
        
        def spatial_pool(self, x, in_type):
            batch, channel, height, width = x.size()
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            if in_type == 0:
                context_mask = self.conv_mask_s(x)
            else:
                context_mask = self.conv_mask_t(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = F.softmax(context_mask, dim=2)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
    
            return context
    
    
        def get_rela_loss(self, preds_S, preds_T):
            loss_mse = nn.MSELoss(reduction='sum')
    
            context_s = self.spatial_pool(preds_S, 0)
            context_t = self.spatial_pool(preds_T, 1)
    
            out_s = preds_S
            out_t = preds_T
    
            channel_add_s = self.channel_add_conv_s(context_s)
            out_s = out_s + channel_add_s
    
            channel_add_t = self.channel_add_conv_t(context_t)
            out_t = out_t + channel_add_t
    
            rela_loss = loss_mse(out_s, out_t)/len(out_s)
            
            return rela_loss
    
    
        def last_zero_init(self, m):
            if isinstance(m, nn.Sequential):
                constant_init(m[-1], val=0)
            else:
                constant_init(m, val=0)
    
        
        def reset_parameters(self):
            kaiming_init(self.conv_mask_s, mode='fan_in')
            kaiming_init(self.conv_mask_t, mode='fan_in')
            self.conv_mask_s.inited = True
            self.conv_mask_t.inited = True
    
            self.last_zero_init(self.channel_add_conv_s)
            self.last_zero_init(self.channel_add_conv_t)
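
    The foreground and background masks built in `forward` are easier to follow on a tiny grid. The sketch below redoes that step for a single image in plain Python; the function names, the 4x4 grid and the two boxes are made up for illustration, and the boxes are clamped to the grid inside the sketch.

```python
import math


def foreground_mask(boxes_xyxy, H, W):
    """boxes_xyxy: normalized (x1, y1, x2, y2) boxes of ONE image; returns an H x W mask."""
    mask = [[0.0] * W for _ in range(H)]
    for x1, y1, x2, y2 in boxes_xyxy:
        # Scale to feature-map cells and clamp to valid indices.
        wmin, wmax = max(0, math.floor(x1 * W)), min(W - 1, math.ceil(x2 * W))
        hmin, hmax = max(0, math.floor(y1 * H)), min(H - 1, math.ceil(y2 * H))
        weight = 1.0 / ((hmax + 1 - hmin) * (wmax + 1 - wmin))  # 1 / box area in cells
        for r in range(hmin, hmax + 1):
            for c in range(wmin, wmax + 1):
                # Where boxes overlap, the larger weight (smaller box) wins.
                mask[r][c] = max(mask[r][c], weight)
    return mask


def background_mask(fg):
    """Uniform weights over the cells that no box covers, summing to 1."""
    n_bg = sum(1 for row in fg for v in row if v == 0.0)
    if n_bg == 0:
        return [[0.0 for _ in row] for row in fg]
    return [[1.0 / n_bg if v == 0.0 else 0.0 for v in row] for row in fg]


# One large box (3x3 cells) and one small box (2x2 cells) overlapping in cell (2, 2).
fg = foreground_mask([(0.0, 0.0, 0.5, 0.5), (0.6, 0.6, 0.7, 0.7)], H=4, W=4)
bg = background_mask(fg)
```

    Each cell of the large box gets weight 1/9 and each cell of the small box 1/4, so small objects are not drowned out by large ones; the four uncovered cells share the background weight equally.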
    

    Instantiating FeatureLoss

    In train.py, instantiate the FeatureLoss defined above. Since three layers are distilled, we need a list of distillation losses:

    if opt.teacher_weights:
        student_kd_layers = hyp["student_kd_layers"]
        teacher_kd_layers = hyp["teacher_kd_layers"]
        # a dummy forward pass is enough to read the channel count of each distilled layer
        dummy_image = torch.zeros((1, 3, imgsz, imgsz), device=device)
        _, features = model(dummy_image, extra_features=student_kd_layers)  # forward
        _, teacher_features = teacher_model(dummy_image, extra_features=teacher_kd_layers)
        kd_losses = []
        for feature, teacher_feature in zip(features, teacher_features):
            student_channels = feature.shape[1]
            teacher_channels = teacher_feature.shape[1]
            kd_losses.append(FeatureLoss(student_channels, teacher_channels))
    

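    Here hyp["xxx_kd_layers"] specifies the indices of the layers to distill.
    To extract the feature maps of those layers, the model's forward code also has to be modified. That is left for the next post; this one walks through the main flow first.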
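    Until the modified forward pass is shown, one generic way to grab intermediate feature maps in PyTorch is a forward hook. The sketch below is an alternative technique, not the `extra_features` argument used in this post; `collect_features`, the toy network and its layer indices are all assumptions for illustration.

```python
import torch
import torch.nn as nn


def collect_features(model: nn.Sequential, x: torch.Tensor, layer_ids):
    """Run `model` once and return (output, feature maps of the layers in layer_ids)."""
    cache = {}
    handles = [
        model[i].register_forward_hook(
            # i=i binds the current index; the hook stores that layer's output
            lambda module, inputs, output, i=i: cache.__setitem__(i, output)
        )
        for i in layer_ids
    ]
    try:
        out = model(x)
    finally:
        for h in handles:
            h.remove()  # always detach the hooks again
    return out, [cache[i] for i in layer_ids]


# Toy backbone; the layer indices below are illustrative, not YOLOv7's.
toy = nn.Sequential(
    nn.Conv2d(3, 16, 3, stride=2, padding=1),   # 0
    nn.ReLU(),                                  # 1
    nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 2
    nn.ReLU(),                                  # 3
)
out, feats = collect_features(toy, torch.zeros(1, 3, 64, 64), layer_ids=[1, 3])
channels = [f.shape[1] for f in feats]  # channel counts, as needed by FeatureLoss
```

    Hooks leave the model's forward code untouched, at the cost of registering and removing them around each pass.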

    Distillation training

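    As with an ordinary loss, training first computes the distillation loss and then backpropagates. The only difference is that computing the distillation loss also requires running the teacher model on the same batch.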

    if opt.teacher_weights:
        pred, features = model(imgs, extra_features=student_kd_layers)  # forward
        with torch.no_grad():  # the teacher only provides targets
            _, teacher_features = teacher_model(imgs, extra_features=teacher_kd_layers)
        # `and` binds tighter than `or`; the parentheses only make the original grouping explicit
        if "loss_ota" not in hyp or (hyp["loss_ota"] == 1 and epoch >= ota_start):
            loss, loss_items = compute_loss_ota(pred, targets.to(device), imgs)
        else:
            loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
        # kd loss: loss_items becomes (box, obj, cls, kd, total)
        loss_items = torch.cat((loss_items[:3], torch.zeros(1, device=device), loss_items[3:4]))
        loss_items[-1] *= imgs.shape[0]
        for i in range(len(features)):
            kd_loss, kd_loss_item = kd_losses[i](features[i], teacher_features[i], targets.to(device), [imgsz, imgsz])
            loss += kd_loss
            loss_items[3] += kd_loss_item
            loss_items[4] += kd_loss_item
    

    Here kd_loss is accumulated onto loss. Once the total loss is computed, everything else is the same as ordinary training.
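
    The bookkeeping on `loss_items` in the snippet above can be restated with plain numbers. This is only a sketch of the logged values; the function name and the example numbers are made up.

```python
def merge_kd_into_log(loss_items, kd_items, batch_size):
    """loss_items: [box, obj, cls, total] from the detection loss (plain floats here).
    kd_items: one distillation loss value per distilled layer.
    Returns [box, obj, cls, kd, total]."""
    box, obj, cls, total = loss_items
    total = total * batch_size   # mirrors `loss_items[-1] *= imgs.shape[0]`
    kd = sum(kd_items)           # slot 3 accumulates the per-layer kd losses
    return [box, obj, cls, kd, total + kd]


logged = merge_kd_into_log([0.05, 0.02, 0.01, 0.08], kd_items=[0.3, 0.2, 0.1], batch_size=16)
```

    The extra slot lets the distillation term be logged separately while the last slot still tracks the total.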

    Conclusion

    This post gave a brief overview of distilling YOLOv7; more details will follow in the next post.

  • Original article: https://blog.csdn.net/liuhao3285/article/details/133895411