• python:神经网络的卷积核,权重矩阵长什么样子?


    很多人在做深度学习的时候,都把神经网络看成了一个黑盒子,只管输入不同的训练样本和标签,就可以预测出来和训练样本标签相似的结果。想必不少人也研究过神经网络的计算过程,在研究中一定会学到梯度下降算法和损失函数,也一定会了解到卷积核和反向传播求导等概念。那么你一定会好奇我们训练出来的模型到底是什么?它长什么样子?卷积核长什么样子吧?本文通过python代码读取训练好的 PointNet 神经网络模型让你看看黑盒子到底长什么样?


    首先,我们使用torch模块,读取本地的.pth文件。

    ①打印模型,我们可以看到.pth保存的是一个字典。

    import torch
    model_dir='D:\\PointNet\\best_model.pth'
    checkpoint = torch.load(model_dir)
    print(checkpoint)
    
    • 1
    • 2
    • 3
    • 4

    ②打印模型字典的keys,关键词。

    import torch
    model_dir='D:\\PointNet\\best_model.pth'
    checkpoint = torch.load(model_dir)
    for dic in checkpoint:
        print(dic)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    在这里插入图片描述
    可以看到模型保存了训练迭代次数epoch,分类精度iou,模型参数(卷积核等)字典,优化函数参数。
    ③打印epoch。

    print(checkpoint['epoch'])
    
    • 1

    在这里插入图片描述
    ④打印分类精度。

    print(checkpoint['class_avg_iou'])
    
    • 1

    在这里插入图片描述
    ⑤打印参数字典(卷积核等)

    littleDic = checkpoint['model_state_dict']
    for dic in littleDic:
        print(dic)
    
    • 1
    • 2
    • 3

    在这里插入图片描述
    ⑥打印conv1的卷积核,形状

    a = checkpoint['model_state_dict']
    print(a['feat.stn.conv1.weight'], "\n\n", a['feat.stn.conv1.weight'].shape)
    
    • 1
    • 2

    在这里插入图片描述
    在这里插入图片描述

    这里可以看到卷积核的大小为64页9行1列,即这一层网络有64个卷积核,这就是PointNet中点云的特征列数从3列变成64列的原因。

    ⑦打印优化方法参数

    littledic = checkpoint['optimizer_state_dict']
    for dic in littledic:
        print(dic)
    
    • 1
    • 2
    • 3

    在这里插入图片描述
    打印state

    littledic = checkpoint['optimizer_state_dict']['state']
    for dic in littledic:
        print(dic)
    
    • 1
    • 2
    • 3

    在这里插入图片描述

    print(checkpoint['optimizer_state_dict'])
    
    • 1

    在这里插入图片描述
    在这里插入图片描述


    通过以上操作,我们可以看出来,卷积核就是矩阵,神经网络就是利用梯度下降法去拟合最佳的卷积核,卷积核和样本数据的特征矩阵相乘就是预测结果。所以神经网络是可以解释的数学方法。所以你懂黑盒子了吗?

  • 相关阅读:
    ZZNUOJ_C语言1013:求两点间距离(完整代码)
    Docker:部署微服务集群
    [学习笔记]CS224W
    Compose 组件 - 分页器 HorizontalPager、VerticalPager
    vue3新特性v-bind in CSS
    mybatisPlus的简单使用
    华为云服务器安装Linux并实现本地连接访问
    promise实现koa2洋葱中间件模型
    有关cache的dirty比特位和Valid比特位的理解
    Qt day04
  • 原文地址:https://blog.csdn.net/qq_35591253/article/details/127671790