码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • 图卷积神经网络层的pytorch复现


    图卷积神经网络层的pytorch复现

            • 基本概念:
            • 图卷积层的数学描述:
            • 图的总体架构:
            • 图卷积层pytorch代码实现[^2]和注释:
            • 参考:


    基本概念:

    图结构非常常见,属于非欧式空间,例如社交网络图、知识图谱、用户点击购买产品产生的关系图、分子结构图、人体关节点连接图。图卷积神经网络算法是一种根据图卷积和神经网络的理论,应用于广泛存在的图结构的实体的算法。图卷积来源于二维卷积,神经网络算法相当于在传统机器学习算法上加上可以学习的权重,使用梯度下降算法更新权重。总的来说,图卷积神经网络是一种结合信号处理和神经网络应用于图结构的一种新算法。具体应用可以对节点、边和整个图进行分类、分割、检测等应用。本文主要记录学习图卷积神经网络的一些理论和想法。

    图卷积层的数学描述:

    图卷积层经过很多的优化和迭代,目前比较主流的一种方法是每一层的复杂度更低,而通过堆叠多层进行更深层次的学习的方法进行学习。具体的推导过程在文献1中,这里省略大篇幅的推导过程。

    多层的图卷积网络按照下面的逐层递推规则:
    A ~ = A + I N D ~ i i = ∑ j A ~ i j \widetilde{A} = A+I_N \\ \widetilde{D}_{ii} = \sum_j{\widetilde{A}_{ij}} A =A+IN​D ii​=j∑​A ij​

    H ( l + 1 ) = σ ( D ~ − 1 2 A ~ D ~ − 1 2 H ( l ) W ( l ) ) H^{(l+1)} = \sigma(\widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)}) H(l+1)=σ(D −21​A D −21​H(l)W(l))

    式子中的 A A A指的是图的邻接矩阵形式, W ( l ) W^{(l)} W(l)指的是可学习的权重, H H H是图节点的最初的特征矩阵 H 0 H^{0} H0经过每一层变换后的矩阵, σ ( ) \sigma() σ()指的是激活函数。邻接矩阵和拉普拉斯矩阵可以参考2。

    图的总体架构:

    图的总体架构如下所示,本篇文章需要实现的就是里面的hidden layers,GraphConvolutionLayer不改变图的结构,所以图结构进过图卷积神经网络层后仍然保持原来的结构。但是后面层的节点能够聚合前面层的节点信息。类似于卷积神经网络的“视野”的概念。深层得到更多的语义信息,浅层则保留更多的原始特征信息。

    在这里插入图片描述

    图卷积层pytorch代码实现3和注释:
    # -*- coding: utf-8 -*-
    # # @Use     : Paper reproduction
    # # @Time    : 2022/8/11 21:30
    # # @FileName: GraphConvolutionLayer.py
    # # @Software: PyCharm
    # # @Paper   : Spectral Networks and Locally Connected Networks on Graphs
    
    
    import torch
    import torch.nn as nn
    
    
    class GraphConvolutionLayer(nn.Module):
        """
        图卷积神经网络
        """
    
        def __init__(self, input_dim, output_dim, adjacency_matrix=None, use_bias=True):
            super(GraphConvolutionLayer, self).__init__()
    
            self.input_dim = input_dim
            self.output_dim = output_dim
            self.use_bias = use_bias
            data = torch.tensor(input_dim, output_dim)
            self.weight = nn.Parameter(data=data)
            if self.use_bias:
                self.bias = nn.Parameter(torch.tensor(input_dim, output_dim))
            else:
                self.register_parameter('bias', None)
            self.reset_parameters()
            self.L_matrix = self.calculate_L_matrix(adjacency_matrix)
    
        def reset_parameters(self):
            """
            重置权重
            """
            nn.init.kaiming_normal_(self.weight)
            if self.use_bias:
                nn.init.zeros_(self.bias)
    
        def forward(self, input_feature):
            """
            邻接矩阵是稀疏矩阵,使用稀疏矩阵的乘法
            @param input_feature:输入特征
            """
            # 计算图卷积的输出
            # (\widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)})
            suport = torch.mm(input_feature, self.weight)
            output = torch.sparse.mm(self.L_matrix, suport)  # 注意因为邻接矩阵是稀疏矩阵,所以使用稀疏矩阵乘法提高效率
            if self.use_bias:
                output += self.bias
            return output
    
        @staticmethod
        def calculate_L_matrix(adjcency: torch.Tensor) -> torch.Tensor:
            """
            根据图的邻接矩阵计算矩阵L_matrix
            L_matrix = \widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}
            """
            dim = adjcency.shape[0]
            A_ware = adjcency + torch.eye(dim)  # 生成单位矩阵
            D_ii = torch.flatten(torch.sum(A_ware, dim=0))  # 按照列进行求和,并且展平成一维向量
            D_ware = torch.diag_embed(D_ii)  # 转换成对角矩阵
            D_ware_temp = torch.pow(D_ware, -0.5)  # 求对角阵的-1/2指数
            L_matrix = torch.mm(torch.mm(D_ware_temp, A_ware), D_ware_temp)  # 使用广播机制进行矩阵乘法
            return L_matrix
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    参考:

    其他参考文献和代码:

    Spectral Networks and Locally Connected Networks on Graphs

    Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering

    mdeff/cnn_graph: Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering (github.com)

    lutional Neural Networks on Graphs with Fast Localized Spectral Filtering (github.com)](https://github.com/mdeff/cnn_graph)


    1. SEMI-SUPERVISED CLASSIFICATION WITH GRAPH CONVOLUTIONAL NETWORKS ↩︎

    2. 图的拉普拉斯矩阵_KPer_Yang的博客-CSDN博客 ↩︎

    3. tkipf/gcn: Implementation of Graph Convolutional Networks in TensorFlow (github.com) ↩︎

  • 相关阅读:
    安培龙冲刺创业板上市:收入依赖美的,邬若军、黎莉夫妇为实控人
    Echart前端的修饰器,你不来看看吗?
    HK32F030MF4P6 SWD管脚功能复用GPIO
    c++ 中 拷贝构造函数 和 operator= 函数 的使用区别
    【我不是熟悉的javascript】使用postMessage+iframe实现授权登录
    好用的问卷工具有什么?5款工具盘点
    Java基于springboot+vue的五金用品销售购物商城系统 前后端分离
    企业微信怎么变更企业名称?
    PEDOT:PSS/甘油酸胆碱([Ch][Glyce])离子液体混合材料
    域控操作三点五:使用策略下发将域用户添加到本地管理员组
  • 原文地址:https://blog.csdn.net/KPer_Yang/article/details/126326770
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号