• 【DETR源码解析】二、Backbone模块


    前言

    最近在看DETR的源码,断断续续看了一星期左右,把主要的模型代码理清了。一直在考虑以什么样的形式写一写DETR的源码解析。考虑的一种形式是像之前写的YOLOv5那样的按文件逐行写,一种是想把源码按功能模块串起来。考虑了很久还是决定按第二种方式,一是因为这种方式可能会更省时间,另外就是也方便我整体再理解一下吧。

    我觉得看代码就是要看到能把整个模型分功能拆开,最后再把所有模块串起来,这样才能达到事半功倍。

    另外一点我觉得很重要的是:拿到一个开源项目代码,要有马上配置环境能够正常运行Debug,并且通过解析train.py马上找到主要模型相关的内容,然后着重关注模型方面的解析,像一些日志、计算mAP、画图等等代码,完全可以不看,可以省很多时间,所以以后我讲解源码都会把无关的代码完全剥离,不再讲解,全部精力关注模型、改进、损失等内容。

    这一节主要讲一下DETR的Backbone部分,包括CNN和位置编码两个模块的代码。主要涉及models/backbone.py和models/position_encoding.py两个文件。

    Github注释版源码:HuKai97/detr-annotations

    一、Backbone整体结构

    整个Backbone主要包括CNN特征提取和位置编码两个部分。代码还是比较简单的,下面开始解析源码。

    首先是调用models/Backbone.py中的build_backbone函数创建Backbone:

    def build_backbone(args):
        # 搭建backbone
        # 位置编码  PositionEmbeddingSine()
        position_embedding = build_position_encoding(args)
        train_backbone = args.lr_backbone > 0   # 是否需要训练backbone  True
        return_interm_layers = args.masks       # 是否需要返回中间层结果 目标检测False  分割True
        # 生成backbone  resnet50
        backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
        # 将backbone输出与位置编码相加   0: backbone   1: PositionEmbeddingSine()
        model = Joiner(backbone, position_embedding)
        model.num_channels = backbone.num_channels   # 512
        return model
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    这里首先调用build_position_encoding函数生成正余弦位置编码position_embedding:[bs,256,H/32, W/32],其中256前128是y方向位置编码,后128是x方向位置编码;再调用Backbone类生成ResNet50对输入数据进行特征提取得到特征图[bs,2048,H/32, W/32]。最后Joiner将两者合并存储起来,方便后续使用。

    一、CNN-Backbone

    创建ResNet50,先调用Backbone类:

    class Backbone(BackboneBase):
        """ResNet backbone with frozen BatchNorm."""
        def __init__(self, name: str,
                     train_backbone: bool,
                     return_interm_layers: bool,
                     dilation: bool):
            # 直接掉包 调用torchvision.models中的backbone
            backbone = getattr(torchvision.models, name)(
                replace_stride_with_dilation=[False, False, dilation],
                pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
            # resnet50  2048
            num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
            super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    这个类是继承自BackboneBase类的,而且CNN直接调用的就是torchvision.models中的模型,所以直接看BackboneBase类:

    class BackboneBase(nn.Module):
    
        def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
            super().__init__()
            for name, parameter in backbone.named_parameters():
                # layer0 layer1不需要训练 因为前面层提取的信息其实很有限 都是差不多的 不需要训练
                if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                    parameter.requires_grad_(False)
            # False 检测任务不需要返回中间层
            if return_interm_layers:
                return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
            else:
                return_layers = {'layer4': "0"}
            # 检测任务直接返回layer4即可  执行torchvision.models._utils.IntermediateLayerGetter这个函数可以直接返回对应层的输出结果
            self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
            self.num_channels = num_channels
    
        def forward(self, tensor_list: NestedTensor):
            """
            tensor_list: pad预处理之后的图像信息
            tensor_list.tensors: [bs, 3, 608, 810]预处理后的图片数据 对于小图片而言多余部分用0填充
            tensor_list.mask: [bs, 608, 810] 用于记录矩阵中哪些地方是填充的(原图部分值为False,填充部分值为True)
            """
            # 取出预处理后的图片数据 [bs, 3, 608, 810] 输入模型中  输出layer4的输出结果 dict '0'=[bs, 2048, 19, 26]
            xs = self.body(tensor_list.tensors)
            # 保存输出数据
            out: Dict[str, NestedTensor] = {}
            for name, x in xs.items():
                m = tensor_list.mask  # 取出图片的mask [bs, 608, 810] 知道图片哪些区域是有效的 哪些位置是pad之后的无效的
                assert m is not None
                # 通过插值函数知道卷积后的特征的mask  知道卷积后的特征哪些是有效的  哪些是无效的
                # 因为之前图片输入网络是整个图片都卷积计算的 生成的新特征其中有很多区域都是无效的
                mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
                # out['0'] = NestedTensor: tensors[bs, 2048, 19, 26] + mask[bs, 19, 26]
                out[name] = NestedTensor(x, mask)
            # out['0'] = NestedTensor: tensors[bs, 2048, 19, 26] + mask[bs, 19, 26]
            return out
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37

    这个类还是在调用torchvision.models中的模型,然后再把预处理后的图片数据[bs, 3, 608, 810]和mask数据[bs, 608, 810]输入到模型中(这个图片数据是经过pad填充的数据,而mask数据就是记录这些图片哪些像素位置是pad的,为True,没用pad的真实有效数据就为False)。经过前向传播,再调用IntermediateLayerGetter函数把对应层特征图提取出来,得到原图32倍下采样的特征图[bs, 2048, 19, 26],以及这张特征图对应的mask[bs, 19, 26]。

    二、Positional Encoding

    Positional Encoding 就是位置编码。这里主要是调用models/position_encoding.py中的build_position_encoding函数创建位置编码:

    def build_position_encoding(args):
        """
        创建位置编码
        args: 一系列参数  args.hidden_dim: transformer中隐藏层的维度   args.position_embedding: 位置编码类型 正余弦sine or 可学习learned
        """
        # N_steps = 128 = 256 // 2  backbone输出[bs,256,25,34]  256维度的特征
        # 而传统的位置编码应该也是256维度的, 但是detr用的是一个x方向和y方向的位置编码concat的位置编码方式  这里和ViT有所不同
        # 二维位置编码   前128维代表x方向位置编码  后128维代表y方向位置编码
        N_steps = args.hidden_dim // 2
        if args.position_embedding in ('v2', 'sine'):
            # TODO find a better way of exposing other arguments
            # [bs,256,19,26]  dim=1时  前128个是y方向位置编码  后128个是x方向位置编码
            position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
        elif args.position_embedding in ('v3', 'learned'):
            position_embedding = PositionEmbeddingLearned(N_steps)
        else:
            raise ValueError(f"not supported {args.position_embedding}")
    
        return position_embedding
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    可以看到,源码是实现了两种位置编码,一种是正余弦绝对位置编码,不需要额外的参数学习,另一种是可学习绝对位置编码。原论文用的是正余弦绝对位置编码,而且代码也是默认使用这个的,所以这里主要介绍PositionEmbeddingSine类:

    class PositionEmbeddingSine(nn.Module):
        """
        Absolute pos embedding, Sine.  没用可学习参数  不可学习  定义好了就固定了
        This is a more standard version of the position embedding, very similar to the one
        used by the Attention is all you need paper, generalized to work on images.
        """
        def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
            super().__init__()
            self.num_pos_feats = num_pos_feats    # 128维度  x/y  = d_model/2
            self.temperature = temperature        # 常数 正余弦位置编码公式里面的10000
            self.normalize = normalize            # 是否对向量进行max规范化   True
            if scale is not None and normalize is False:
                raise ValueError("normalize should be True if scale is passed")
            if scale is None:
                # 这里之所以规范化到2*pi  因为位置编码函数的周期是[2pi, 20000pi]
                scale = 2 * math.pi  # 规范化参数 2*pi
            self.scale = scale
    
        def forward(self, tensor_list: NestedTensor):
            x = tensor_list.tensors   # [bs, 2048, 19, 26]  预处理后的 经过backbone 32倍下采样之后的数据  对于小图片而言多余部分用0填充
            mask = tensor_list.mask   # [bs, 19, 26]  用于记录矩阵中哪些地方是填充的(原图部分值为False,填充部分值为True)
            assert mask is not None
            not_mask = ~mask   # True的位置才是真实有效的位置
    
            # 考虑到图像本身是2维的 所以这里使用的是2维的正余弦位置编码
            # 这样各行/列都映射到不同的值 当然有效位置是正常值 无效位置会有重复值 但是后续计算注意力权重会忽略这部分的
            # 而且最后一个数字就是有效位置的总和,方便max规范化
            # 计算此时y方向上的坐标  [bs, 19, 26]
            y_embed = not_mask.cumsum(1, dtype=torch.float32)
            # 计算此时x方向的坐标    [bs, 19, 26]
            x_embed = not_mask.cumsum(2, dtype=torch.float32)
    
            # 最大值规范化 除以最大值 再乘以2*pi 最终把坐标规范化到0-2pi之间
            if self.normalize:
                eps = 1e-6
                y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
                x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
    
            dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)   # 0 1 2 .. 127
            # 2i/2i+1: 2 * (dim_t // 2)  self.temperature=10000   self.num_pos_feats = d/2
            dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)   # 分母
    
            pos_x = x_embed[:, :, :, None] / dim_t   # 正余弦括号里面的公式
            pos_y = y_embed[:, :, :, None] / dim_t   # 正余弦括号里面的公式
            # x方向位置编码: [bs,19,26,64][bs,19,26,64] -> [bs,19,26,64,2] -> [bs,19,26,128]
            pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
            # y方向位置编码: [bs,19,26,64][bs,19,26,64] -> [bs,19,26,64,2] -> [bs,19,26,128]
            pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
            # concat: [bs,19,26,128][bs,19,26,128] -> [bs,19,26,256] -> [bs,256,19,26]
            pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
    
            # [bs,256,19,26]  dim=1时  前128个是y方向位置编码  后128个是x方向位置编码
            return pos
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53

    对照公式:
    在这里插入图片描述
    我的几个关键点的理解:

    1. 这里是通过mask来构建位置编码的,mask中记录了特征图中每个像素位置是否是pad的,只有在为False的位置,才是有效的位置,才需要构建位置编码;
    2. 关于最大值规范化 :因为正余弦编码方式,思想就是将各个位置的通过公式映射到 0~2Π 这个范围内(也可以是4Π,6Π,8Π…,因为它是一个周期函数,不过我们一般默认为2Π),所以这里在带入公式之前需要对x_embed、y_embed先进行规范化;
    3. 关于位置编码方式:这里之所以是把x和y分别进行位置编码(二维位置编码),而不是像transformer那样的一维位置编码。主要考虑的是transformer是应用在语言模型中的,天然就是一维的,所以一维可能更适合,而DETR是应用在图像任务中的一个目标检测框架,在图像任务中,当然二维位置编码效果可能会更好点;
    4. 这样,对于每个位置(x,y),其所在列对应的编码值就在通道维度的前128维,其所在行的编码值就在通道这个维度的后128维。这样这个特征图上各个位置就都对应到不同的维度的编码值了。

    当然作为学习,也可以看看第二种绝对位置编码方式:可学习位置编码:

    class PositionEmbeddingLearned(nn.Module):
        """
        Absolute pos embedding, learned.
        可以发现整个类其实就是初始化了相应shape的位置编码参数,让后通过可学习的方式学习这些位置编码参数
        """
        def __init__(self, num_pos_feats=256):
            super().__init__()
            # nn.Embedding  相当于 nn.Parameter  其实就是初始化函数
            self.row_embed = nn.Embedding(50, num_pos_feats)
            self.col_embed = nn.Embedding(50, num_pos_feats)
            self.reset_parameters()
    
        def reset_parameters(self):
            nn.init.uniform_(self.row_embed.weight)
            nn.init.uniform_(self.col_embed.weight)
    
        def forward(self, tensor_list: NestedTensor):
            x = tensor_list.tensors
            h, w = x.shape[-2:]   # 特征图h w
            i = torch.arange(w, device=x.device)
            j = torch.arange(h, device=x.device)
            x_emb = self.col_embed(i)   # 初始化x方向位置编码
            y_emb = self.row_embed(j)   # 初始化y方向位置编码
            # concat x y 方向位置编码
            pos = torch.cat([
                x_emb.unsqueeze(0).repeat(h, 1, 1),
                y_emb.unsqueeze(1).repeat(1, w, 1),
            ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
            return pos
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29

    可以发现整个类其实就是初始化了相应shape的位置编码参数,然后通过可学习的方式自己学习这些位置编码参数,代码比较简答。

    Reference

    官方源码: https://github.com/facebookresearch/detr

    b站源码讲解: 铁打的流水线工人

    知乎【布尔佛洛哥哥】: DETR 源码解读

    CSDN【在努力的松鼠】源码讲解: DETR源码笔记(一)

    CSDN【在努力的松鼠】源码讲解: DETR源码笔记(二)

    CSDN: Transformer中的position encoding(位置编码一)

    知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(一)、概述与模型推断】

    知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(二)、模型训练过程与数据处理】

    知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(三)、Backbone与位置编码】

    知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(四)、Detection with Transformer】

    知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(五)、loss函数与匈牙利匹配算法】

    知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(六)、模型输出与预测生成】

  • 相关阅读:
    Mysql数据库
    Word控件Spire.Doc 【图像形状】教程(1) ;如何在 Word 中插入图像(C#/VB.NET)
    Kubernetes API的流编解码器StreamSerializer
    中文语音识别转文字的王者,阿里达摩院FunAsr足可与Whisper相颉顽
    资源管理平台头部导航栏(1+X Web前端开发初级 例题)
    MySQL 百万级/千万级表 全量更新
    linux的ls命令
    芯片启动以及boot
    CMake中target_link_libraries的使用
    基础算法篇——位运算
  • 原文地址:https://blog.csdn.net/qq_38253797/article/details/127614228