• (一)STDCNet源码解读



    欢迎访问个人网络日志🌹🌹知行空间🌹🌹


    1.简介

    论文地址:https://arxiv.org/abs/2104.13188

    仓库地址:https://github.com/MichaelFan01/STDC-Seg

    STDCMNet(Short Term Dense Concatenate Network)网络是美团2021年04月27号提交的论文Rethinking BiSeNet For Real-time Semantic Segmentation中提出的轻量级语义分割网络,该网络是在BiSeNet v1/v2基础上的升级改进。STDCNet主要贡献有两点,一方面是对骨干网络backbone的改进,改成了Dense Concatenate的模块结构,同一个STDC模块中,每个ConvX随着感受野的变大输出的通道数逐渐变少,最后再Concatenate到一起,因此包含更多的特征尺度信息。另一方面是多分支低阶细节信息辅助训练结构,detail information guidance结构只在训练的时候使用,网络训练完成后可以直接舍弃,这种方法相对于之前的BiSeNnet可以减少推理时的计算量。

    2.网络结构

    在这里插入图片描述

    如上图,网络的backbone包含5stage,第istage的输出feature map的尺寸是原来的 1 2 i \frac{1}{2^i} 2i1,satge 4&5输出的feature map经过ARM(Attention Refine Module)之后包含更多的语义信息,组成context path,前3stage输出的feature map包含更多的图像细节信息,两者特征融合经SegHead后直接向上最近邻resize输出最终的分割图。Seg Loss使用的是OhemLoss。网络对于低层stage使用Detail Loss做训练,以提升低层stage feature map提取图像细节信息的能力。对于前3个stage输出的feature map使用与SegHead同样结构的Detail Head做处理得到Detail的输出用来计算Detail Loss,**值得注意的是SegHead输出的最终channels数量是分割的类别数,而Detail Head输出的channels数是1,即是边缘的置信度。**计算Detail Loss时,先对ground truthstride=[1,2,4]Laplacian Convolution,将不同size的卷积结果再stack到一起,经过3个可训练的1x1的卷积后得到Detail Ground Truth用来计算Detail Loss。根据源码,从网络输出的角度整理出来的网络结构如下图:

    2.1 Detail Guidance

    在这里插入图片描述

    如上图橙色倒金字塔中表示不同stage卷积输出的feature map,从上到小feature mapsize逐渐变小,channel逐渐变大。在前几个stage输出的feature map尺度更大,包含了更多的图像细节信息,STDCNet的创新之一就是,增加了Detail Guidance Traning分支,训练时对前几个stage输出的特征图计算loss来提升低层卷积对图像细节提取的能力,这一部分如上图中所示,只在训练时有用,在推理时,直接取低层卷积的feature map与包含更多语义信息的高层卷积feature map做融合,相对于BiSeNet减少了推理时的计算量,提升了模型的推理速度。

    Detail Guidance辅助训练可以参考图2,其对stage 1/2/3输出的feature 2/4/8来做训练,提升的是模型低层卷积提取图像细节信息的能力,Detail Ground Truth的生成也可参考图2

    Detail Guidance是对图像边缘做训练故只有2个类,可以使用二分类交叉熵损失函数,如图中所示,Detail Ground Truth中大部分都是黑色的背景,只有少量的表示边缘的像素,因此是严重的类别不平衡问题

    L d e t a i l ( p d , g d ) = L d i c e ( p d , g d ) + L b c e ( p d , g d ) L_{detail}(p_d, g_d)=L_{dice}(p_d, g_d) + L_{bce}(p_d, g_d) Ldetail(pd,gd)=Ldice(pd,gd)+Lbce(pd,gd)
    p d , g d p_d,g_d pd,gd分别表示对应像素位置的值,d表示detail,其中,

    L d i c e = 1 − 2 ∑ i H × W p d i g d i + ϵ ∑ i H × W ( p d i ) 2 + ∑ i H × W ( g d i ) 2 + ϵ L_{dice} = 1 - \frac{2\sum^{H\times W}_{i}p_d^ig_d^i+\epsilon}{\sum^{H\times W}_{i}(p_d^i)^2+\sum^{H\times W}_{i}(g_d^i)^2+\epsilon} Ldice=1iH×W(pdi)2+iH×W(gdi)2+ϵ2iH×Wpdigdi+ϵ
    ϵ \epsilon ϵ是为了防止除0,通常取1dice loss计算参考:

    def dice_loss_func(input, target):
        smooth = 1.
        n = input.size(0)
        iflat = input.view(n, -1)
        tflat = target.view(n, -1)
        intersection = (iflat * tflat).sum(1)
        loss = 1 - ((2. * intersection + smooth) /
                    (iflat.sum(1) + tflat.sum(1) + smooth))
        return loss.mean()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    2.2 Short Term Dense Concatenate

    在这里插入图片描述

    如上图,图a中表示的是网络backbone的整体结构,网络总共分成了6stage,其中前5stage用作分割的backbone,第istage输出的特征图的大小为原来 H × W H\times W H×W 1 2 i \frac{1}{2^i} 2i1,feature map的通道逐渐变大,为 16 × 2 i 16\times 2^i 16×2i,源码中,当stage 2输出的特征图通道数大于64时,会对stage 5的输出增加一个last_conv,只是为了使stage 5输出特征图的通道数不少于1024。图b表示的是每个stage中使用的Short Dense Concatenate Module,从图中可以看到每个STDCModule包括4ConvX Block,且卷积所属层级越高,输出的通道数越少,最后将这些不同卷积的输出再直接Concatenate到一起,论文中有一段介绍,STDC的理由是低层卷积感受野小需要更多的通道来提取细节信息,高层卷积有更大的感受野,只需较小的通道数即可得到足够的语义信息。

    3.其他

    3.1 Attention Refine Module

    Attention Refine ModuleBiSeNet中提出的结构,用于ContextPath中,衡量feature map每个通道上的重要程度,其计算过程是先把输入feature经过kernel=3,stride=1,padding=1的卷积,再通过Global Average Pool处理输出NxCx1x1的张量,后对其经过stride=1,kernel=1,bias=False的卷积和Sigmoid的函数,输出元素值在0-1上的评分张量NxCx1x1,取此张量与原feature相乘,得到最后对每个通道上乘以评分后的输出。

    在这里插入图片描述

    源代码实现:

    class AttentionRefinementModule(nn.Module):
        def __init__(self, in_chan, out_chan, *args, **kwargs):
            super(AttentionRefinementModule, self).__init__()
            self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
            self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
            # self.bn_atten = BatchNorm2d(out_chan)
            self.bn_atten = BatchNorm2d(out_chan, activation='none')
    
            self.sigmoid_atten = nn.Sigmoid()
            self.init_weight()
    
        def forward(self, x):
            feat = self.conv(x)
            atten = F.avg_pool2d(feat, feat.size()[2:])
            atten = self.conv_atten(atten)
            atten = self.bn_atten(atten)
            atten = self.sigmoid_atten(atten)
            out = torch.mul(feat, atten)
            return out
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    3.2 FeatureFusionModule

    STDCNet中,因其同BiSeNet结构,分成了Spatial PathContext Path,Feature Fusion Module特征融合模块将下采样8倍的Spatial PathContext Path上的feature map融合到一起得到最终的分割效果,使分割结果即包含足够的细节也还能保持好的语义信息。FFM也是BiSeNet中提出的。

    在这里插入图片描述

    class FeatureFusionModule(nn.Module):
        def __init__(self, in_chan, out_chan, *args, **kwargs):
            super(FeatureFusionModule, self).__init__()
            self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
            self.conv1 = nn.Conv2d(out_chan,
                    out_chan//4,
                    kernel_size = 1,
                    stride = 1,
                    padding = 0,
                    bias = False)
            self.conv2 = nn.Conv2d(out_chan//4,
                    out_chan,
                    kernel_size = 1,
                    stride = 1,
                    padding = 0,
                    bias = False)
            self.relu = nn.ReLU(inplace=True)
            self.sigmoid = nn.Sigmoid()
            self.init_weight()
    
        def forward(self, fsp, fcp):
            fcat = torch.cat([fsp, fcp], dim=1)
            feat = self.convblk(fcat)
            atten = F.avg_pool2d(feat, feat.size()[2:])
            atten = self.conv1(atten)
            atten = self.relu(atten)
            atten = self.conv2(atten)
            atten = self.sigmoid(atten)
            feat_atten = torch.mul(feat, atten)
            feat_out = feat_atten + feat
            return feat_out
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31

    可以看到ARMFFM结构上有一定的相似性,都属于通道注意力机制,作者在[知乎]上回复评论时指出,这两部分灵感都是来源于2017年9月份提出的SeNet

    3.3 Global Average Pooling

    GAP,Global Average Pooling,即全局均值池化,就是说,均值池化是作用在整张feature map上的,即输入特征图的shapeNXCXHW,经池化后,输出的shapeNXCX1X1,即池化核的大小是整张特征图,因此称之为全局均值池化,同理理解GMP,Global Maximum Pooling。[GAP]最早是在2013年12月提交的Network in Network论文中提出用来替代全连接层的,具体可以参考这篇博客

    图片来自于博客
    在这里插入图片描述

    代码实现:

    import torch
    import torch.nn.functional as F
    s = torch.randint(0, 255, (1, 1, 4, 4)).type(torch.float)
    print(f"before GAP: {s}")
    avg_s = F.adaptive_avg_pool2d(s, (4, 4))
    print(f"after GAP: {avgs}")
    
    # before GAP: tensor([[[[ 13., 125., 111.,  98.],
    #           [ 77.,  17., 227.,  10.],
    #           [ 54., 253., 252., 118.],
    #           [110.,  33.,  99., 233.]]]])
    # after GAP: tensor([[[[129.4658]]]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    其中torch.nn.functional.adaptive_average_pool函数的实现方式参考Question介绍,其原理,

    s t r i d e = i n p u t _ s i z e / / o u t p u t _ s i z e k e r n e l = i n p u t _ s i z e − ( o u t p u t _ s i z e − 1 ) ∗ s t r i d e p a d d i n g = 0

    stride=input_size//output_sizekernel=input_size(output_size1)stridepadding=0" role="presentation" style="position: relative;">stride=input_size//output_sizekernel=input_size(output_size1)stridepadding=0
    stride=input_size//output_sizekernel=input_size(output_size1)stridepadding=0

    3.2 OHEM Loss

    OHEM Loss (Online Hard Example Mining Loss)Focal Loss最初提出都是用来解决检测问题中Positive Proposal BoxesNegative Proposal Boxes类别不平衡问题的,在STDCNet中,对分割输出的训练使用了OHEM LossOHEM Loss在训练过程中,不是使用一个batch中所有的样本来计算损失,而是只使用了损失值较大的一部分样本参与计算损失,这个过程发生在整个训练中,因此是一种online的方法。因其计算loss时会选择损失值大对训练影响大的样本的,因此其能够处理样本不平衡问题。OHEM LossFast RCNN作者Ross Girshick等在2016.04发表的论文Training Region-based Object Detectors with Online Hard Example Mining中提出的。

    源代码定义的OHEM Loss:

    class OhemCELoss(nn.Module):
        def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
            super(OhemCELoss, self).__init__()
            self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()
            self.n_min = n_min
            self.ignore_lb = ignore_lb
            self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
    
        def forward(self, logits, labels):
            N, C, H, W = logits.size()
            loss = self.criteria(logits, labels).view(-1)
            loss, _ = torch.sort(loss, descending=True)
            if loss[self.n_min] > self.thresh:
                loss = loss[loss>self.thresh]
            else:
                loss = loss[:self.n_min]
            return torch.mean(loss)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    3.3 卷积感受野

    在机器视觉领域的深度神经网络中有一个概念叫做感受野,用来表示网络内部的不同位置的神经元对原图像的感受范围的大小。通俗的说,感受野就是输入图像对这一层输出的神经元的影响有多大。如以下图片所示,图片来自于博客

    第1层3x3卷积stride=2,RF=3 第2层3x3卷积stride=2,RF=7 黄色feature map对应的感受野是7*7

    感受野Receptive Field的计算公式为: R F i + 1 = k e r n e l + ( R F i − 1 ) × s t r i d e RF_{i+1} = kernel + (RF_i - 1)\times stride RFi+1=kernel+(RFi1)×stride,其中i表示第i个卷积层,kernelstride是当前层的卷积参数。当然,这里没有考虑paddingpoolingdilation,只讨论了普通的卷积。

    3.4 分割效果的评估

    常用的分割效果评价指标有:

    • 像素准确率(Pixel Accuracy,PA),即分类正确的像素数除以总的像素数,同Accuracy
      P A = ∑ i = 1 k p i i ∑ i = 1 k ∑ j = 1 k p i j PA = \frac{\sum_{i=1}^{k}p_{ii}}{\sum_{i=1}^{k}\sum_{j=1}^{k}p_{ij}} PA=i=1kj=1kpiji=1kpii
      其中k表示的是分割分类的类别数, p i j p_{ij} pij表示的是混淆矩阵上ij列上的数目,PAaccuracy

    • 交并比(Intersection of Union, IoU),即ground truthprediction之间计算的比值,
      I o U = T P T P + F P + F N IoU = \frac{TP}{TP+FP+FN} IoU=TP+FP+FNTP
      其中,TP是True Positive,FP是False Positive, FN 是False Negative。常用的指标是mean IoU, mIou是计算各个类别上的IoU求平均所得,同样的有mPA,见博客mIoU没有考虑类别间像素数量差别较大时的情况,对类别不平衡时有可能会失真,可考虑带权重mIoU

    参考STDCNet源码中计算mIoU的代码:

    class MscEval(object):
        def evaluate(self):
            ## evaluate
            n_classes = self.n_classes
            hist = np.zeros((n_classes, n_classes), dtype=np.float32)
            dloader = tqdm(self.dl)
            if dist.is_initialized() and not dist.get_rank()==0:
                dloader = self.dl
            for i, (imgs, label) in enumerate(dloader):
                N, _, H, W = label.shape
                probs = torch.zeros((N, self.n_classes, H, W))
                probs.requires_grad = False
                imgs = imgs.cuda()
                for sc in self.scales:
                    # prob = self.scale_crop_eval(imgs, sc)
                    prob = self.eval_chip(imgs)
                    probs += prob.detach().cpu()
                probs = probs.data.numpy()
                preds = np.argmax(probs, axis=1)
    
                hist_once = self.compute_hist(preds, label.data.numpy().squeeze(1))
                hist = hist + hist_once
            IOUs = np.diag(hist) / (np.sum(hist, axis=0)+np.sum(hist, axis=1)-np.diag(hist))
            mIOU = np.mean(IOUs)
            return mIOU
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25

    欢迎访问个人网络日志🌹🌹知行空间🌹🌹


    参考资料

  • 相关阅读:
    【Java数据结构】详解LinkedList与链表(二)
    clickhouse学习之路----clickhouse的特点及安装
    mysql数据库基本操作中update修改、delete删除基本操作
    计算机毕业设计Java校园资料在线分享网站(系统+源码+mysql数据库+lw文档)
    MT4和MT5的共同点,anzo capital昂首资本说一个,没人有意见吧
    C++ 标准库类型学习笔记(一)(vector、string 篇)
    Spring中什么样的Bean存在线程安全问题-有状态bean
    Java ClassLoader definePackage()方法具有什么功能呢?
    高级运维学习(九)块存储、文件系统存储和对象存储的实现
    Office365 Excel中使用宏将汉字转拼音
  • 原文地址:https://blog.csdn.net/lx_ros/article/details/126515733