• 卷积核、特征图可视化



    最近在读AlexNet这篇CNN的开山之作,里面有卷积层卷积核可视化这一部分,故记录一下,其他网络也可照猫画虎,希望得到一个宝贵的赞。

    一、卷积核可视化

    1、准备一个训练好的模型

    这里建议使用一个训练好的模型进行可视化,这样可视化出来的结果可以帮助观察出一些潜在的特性的(我这里采用的是AlexNet预训练模型)。

    2、卷积核可视化

    这里简单进行一个梳理
    第一层卷积核:torch.Size([64, 3, 11, 11]),
    输出通道数:64, 对应卷积核的数量
    输入通道数:3, 对应卷积核的通道数
    卷积核宽:11,
    卷积核高:11
    单通道卷积核可视化多通道卷积核可视化见下图:

    代码如下:

    import os
    import torch
    import torch.nn as nn
    from torch.utils.tensorboard import SummaryWriter
    import torchvision.utils as vutils
    import torchvision.models as models
    
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    log_dir = os.path.join(BASE_DIR, "results")
    writer = SummaryWriter(log_dir=log_dir, filename_suffix="_kernel")
    
    path_state_dict = os.path.join("data", "alexnet-owt-4df8aa71.pth")
    alexnet = models.alexnet()
    alexnet.load_state_dict(torch.load(path_state_dict))
    
    kernel_num = -1
    vis_max = 1
    
    # 取前两层卷积核
    for sub_module in alexnet.modules():
        if not isinstance(sub_module, nn.Conv2d):
            continue
        if kernel_num >= vis_max:
            break
        kernel_num += 1
        kernels = sub_module.weight
        c_out, c_int, k_h, k_w = tuple(kernels.shape)  # 输出通道数,输入通道数,卷积核宽,卷积核高
        print(kernels.shape)
        for o_idx in range(c_out):
            kernel_idx = kernels[o_idx, :, :, :].unsqueeze(1)  # 获得(3, h, w), 但是make_grid需要 BCHW,这里拓展C维度变为(3, 1, h, w)
            kernel_grid = vutils.make_grid(kernel_idx, normalize=True, scale_each=True, nrow=8)  # 将卷积核于网格中可视化
            # nrow:每一行显示的图像数
            writer.add_image('{}_Convlayer_split_in_channel'.format(kernel_num), kernel_grid, global_step=o_idx)
        #     名称,图片,第几张图片
        kernel_all = kernels.view(-1, 3, k_h, k_w)  # 3, h, w
        kernel_grid = vutils.make_grid(kernel_all, normalize=True, scale_each=True, nrow=8)  # c, h, w
        writer.add_image('{}_all'.format(kernel_num), kernel_grid, global_step=620)
        print("{}_convlayer shape:{}".format(kernel_num, tuple(kernels.shape)))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    补充操作
    运行代码后会生成一个results文件,接下来
    执行如下代码:tensorboard --logdir= 所存储的results文件路径)
    点击如下链接即可跳转到可视化界面

    结果如下图所示

    二、特征图可视化

    这里只展示第一层,输入一张图片经过第一层卷积后所得特征层的可视化
    代码如下:

    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    log_dir = os.path.join(BASE_DIR, "results")
    
    path_state_dict = os.path.join("data", "alexnet-owt-4df8aa71.pth")
    alexnet = models.alexnet()
    alexnet.load_state_dict(torch.load(path_state_dict))
    writer = SummaryWriter(log_dir=log_dir, filename_suffix="_feature map")
    
    # 数据
    path_img = os.path.join(BASE_DIR,"data", "tiger cat.jpg")  # your path to image
    img_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.49139968, 0.48215827, 0.44653124], [0.24703233, 0.24348505, 0.26158768])
    ])
    
    img_pil = Image.open(path_img).convert('RGB')
    img_tensor = img_transforms(img_pil)
    img_tensor.unsqueeze_(0)  # chw -> bchw
    
    
    convlayer1 = alexnet.features[0]  # 第一层卷积层
    fmap_1 = convlayer1(img_tensor)
    
    fmap_1.transpose_(0, 1)  # bchw=(1, 64, 55, 55) --> (64, 1, 55, 55)
    fmap_1_grid = vutils.make_grid(fmap_1, normalize=True, scale_each=True, nrow=8)
    
    writer.add_image('feature map in conv1', fmap_1_grid, global_step=620)
    writer.close()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29

    下面是原图和效果图:

    如有不足还望指出,最后再次请求一个小小的赞。

  • 相关阅读:
    【Spring Boot】
    Java EE -- Spring
    定制qga(作业截图)
    C++11 lambda+包装器+可变参数模板
    [springMVC]9、处理json数据(@RequestBody,@ResponseBody)
    国家域名后缀有哪些?
    联合投稿其乐融融 抖音共创助你大显身手
    小白学java
    SaaSBase:Blue Prism是什么?
    webpack实战:最新QQ音乐sign参数加密分析
  • 原文地址:https://blog.csdn.net/hjkdh/article/details/125357950