• Centernet 生成高斯热图


    写在前面的话

    最近学校阳了,宿舍给封了,宿舍网络不好远程跑不了实验,随缘写一下对CenterNet源码的一个解读,之前写论文的那段时间留下来的工作,respect!

    这个文章主要是对CenterNet中生成高斯核的部分代码进行解析,具体原理不会细讲,但是本文增加了一个很方便理解的可视化的代码,可以自己拿来跑就行,自己debug应该也可以理解作者的意思,希望对读者有帮助。

    可视化代码下载链接:https://download.csdn.net/download/weixin_42899627/87157112

    Centernet 源码位置
    本文核心代码在CenterNet/src/lib/utils/image.py中可以找到

    二维高斯函数的公式

    在这里插入图片描述

    CenterNet源码中二维高斯函数实现如下:

    tip: 对比公式少了些东西,但是不影响高斯函数的特性,这里关键还是看高斯核半径的计算

    def gaussian2D(shape, sigma=1):
        m, n = [(ss - 1.) / 2. for ss in shape]
        y, x = np.ogrid[-m:m + 1, -n:n + 1]#np.orgin 生成二维网格坐标
    
        h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
        h[h < np.finfo(h.dtype).eps * h.max()] = 0 #np.finfo()常用于生成一定格式,数值较小的偏置项eps,以避免分母或对数变量为零
        return h
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    高斯核半径的计算

    从代码上看就是一元二次方程的求根公式

    这里要注意的代码中计算高斯半径是根据框的角点进行计算,而在Centernet中需要计算的是框的中心点的高斯半径,其实道理是一样的 Centernet 框的角点的偏移可以近似对于框中心点的偏移

    情况一:两角点均在真值框内
    情况二:两角点均在真值框外
    情况三:一角点在真值框内,一角点在真值框外

    参考文章:
    CornerNet Guassian radius高斯半径的确定-数学公式详解
    说点Cornernet/Centernet代码里面GT heatmap里面如何应用高斯散射核

    def gaussian_radius(det_size, min_overlap=0.7):
        height, width = det_size
    
        a1 = 1
        b1 = (height + width)
        c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
        sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
        r1 = (b1 + sq1) / 2
    
        a2 = 4
        b2 = 2 * (height + width)
        c2 = (1 - min_overlap) * width * height
        sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
        r2 = (b2 + sq2) / 2
    
        a3 = 4 * min_overlap
        b3 = -2 * min_overlap * (height + width)
        c3 = (min_overlap - 1) * width * height
        sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
        r3 = (b3 + sq3) / 2
        return min(r1, r2, r3)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    CenterNet源码中 draw_umich_gaussian 函数实现如下:

    tip: 没啥特别的操作,主要是将生成的一个二维高斯核(目标框尺寸)放到原图(图像尺寸)的对应位置上

    def draw_umich_gaussian(heatmap, center, radius, k=1):
        diameter = 2 * radius + 1
        gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
    
        x, y = int(center[0]), int(center[1])
    
        height, width = heatmap.shape[0:2]
    
        left, right = min(x, radius), min(width - x, radius + 1)
        top, bottom = min(y, radius), min(height - y, radius + 1)
    
        masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
        masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
        if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:  # TODO debug
            np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)#逐个元素比较大小,保留大的值
        return heatmap
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    在这里插入图片描述

    import numpy as np
    import math
    import xml.etree.ElementTree as ET
    import glob
    from image import draw_dense_reg, draw_msra_gaussian, draw_umich_gaussian
    from image import get_affine_transform, affine_transform, gaussian_radius
    
    data_dir = r"*.jpg"
    a_file = glob.glob(data_dir)[0]
    print(a_file, a_file.replace(".jpg", ".xml"))
    
    tree = ET.parse(a_file.replace(".jpg", ".xml"))
    root = tree.getroot()
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)
    print(f"原图宽:{width} 高:{height}")
    
    num_classes = 3
    output_h = height
    output_w = width
    hm = np.zeros((num_classes, output_h, output_w), dtype=np.float32)
    
    anns = []
    for obj in root.iter('object'):
        bbox = obj.find('bndbox')
        cate = obj.find('name').text
        # print(cate, bbox.find("xmin").text, bbox.find("xmax").text,
        #       bbox.find("ymin").text, bbox.find("ymax").text)
        xyxy = [int(bbox.find("xmin").text), int(bbox.find("ymin").text),
              int(bbox.find("xmax").text),int(bbox.find("ymax").text)]
        anns.append({"bbox" : xyxy,'category_id':int(cate)})
    
    num_objs = len(anns)
    flipped = False #是否经过全图翻转
    
    import matplotlib.pyplot as plt
    plt.figure(figsize=(19, 6))
    plt.ion()
    plt.subplot(131)
    img = plt.imread(a_file)
    plt.title('Origin_img')
    plt.imshow(img)
    
    for k in range(num_objs):
        ann = anns[k]
        bbox = ann['bbox']
        cls_id = ann['category_id']
        if flipped:
            bbox[[0, 2]] = width - bbox[[2, 0]] - 1
        # bbox[:2] = affine_transform(bbox[:2], trans_output)# 仿射变换
        # bbox[2:] = affine_transform(bbox[2:], trans_output)
        # bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, output_w - 1)#裁剪
        # bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, output_h - 1)
        h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
        if h > 0 and w > 0:
            radius = gaussian_radius((math.ceil(h), math.ceil(w)))
            radius = max(0, int(radius))
            # radius = self.opt.hm_gauss if self.opt.mse_loss else radius
            ct = np.array(
                [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
            ct_int = ct.astype(np.int32)
            plt.subplot(133)
            hm_out, gaussian = draw_umich_gaussian(hm[cls_id], ct_int, radius)
            plt.title('Umich Heatmap')
            # hm_out = draw_msra_gaussian(hm[cls_id], ct_int, radius)
            # print(hm_out.shape)
            # plt.title("Mara Heatmap")
            plt.text(ct[0], ct[1], f"(class:{cls_id})", c='white')
            plt.plot([bbox[0], bbox[2], bbox[2], bbox[0], bbox[0]], [bbox[1], bbox[1], bbox[3], bbox[3], bbox[1]])
            plt.imshow(hm_out)
            plt.subplot(132)
            plt.title(f'Gaussian: bbox_h={h},bbox_w={w}, radius={radius}')
            plt.imshow(gaussian)
            plt.pause(2)
    
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77

    在这里插入图片描述

    参考文章

    1. np.ogrid & np.mgrid 用法
    2. 一维和二维高斯函数及其一阶和二阶导数

  • 相关阅读:
    微服务多模块项目maven打包时报找不到依赖模块中的类
    电商数仓整体理解
    企业进行高质量数据管理,实施数据治理的关键是什么?
    蓝桥杯第四场双周赛(1~6)
    8.11 DAy39---MyBatis面试题
    【力扣白嫖日记】601.体育馆的人流量
    VSCode 在部分 Linux 设备上终端和文本编辑器显示文本不正常的解决方法
    linux服务器配置openssl
    数据科学AB测试(说人话系列)
    景联文科技:打造亿级高质量教育题库,赋能教育大语言模型新未来
  • 原文地址:https://blog.csdn.net/weixin_42899627/article/details/128042986