• 模型训练出现NAN


    【功能模块】完整代码在附件,数据集需要的话也可以提供

    class EmbeddingImagenet(nn.Cell):
        """CNN feature extractor: four conv blocks followed by one fully-connected
        embedding layer.

        Args:
            emb_size (int): dimensionality of the output embedding.
            cifar_flag (bool): if True, use the flattened size for CIFAR-style
                32x32 inputs (hidden * 4); otherwise assume mini-ImageNet-style
                84x84 inputs (hidden * 25). -- inferred from the size arithmetic;
                TODO confirm against the dataset loader.

        Returns (construct):
            Tensor of shape (batch, emb_size).
        """
        def __init__(self, emb_size, cifar_flag=False):
            super(EmbeddingImagenet, self).__init__()
            # set sizes
            self.hidden = 64
            # flattened channel*spatial size after the 4 max-pooled conv blocks
            self.last_hidden = self.hidden * 25 if not cifar_flag else self.hidden * 4
            self.emb_size = emb_size
            self.out_dim = emb_size

            # conv block 1: 3 -> 64 channels, halves spatial size
            self.conv_1 = nn.SequentialCell(nn.Conv2d(in_channels=3,
                                                  out_channels=self.hidden,
                                                  kernel_size=3,
                                                  padding=1,
                                                  pad_mode='pad',
                                                  has_bias=False),
                                        nn.BatchNorm2d(num_features=self.hidden),
                                        nn.MaxPool2d(kernel_size=2, stride=2),
                                        nn.LeakyReLU(alpha=0.2))
            # conv block 2: 64 -> 96 channels
            self.conv_2 = nn.SequentialCell(nn.Conv2d(in_channels=self.hidden,
                                                  out_channels=int(self.hidden*1.5),
                                                  kernel_size=3,
                                                  padding=1,
                                                  pad_mode='pad',
                                                  has_bias=False),
                                        nn.BatchNorm2d(num_features=int(self.hidden*1.5)),
                                        nn.MaxPool2d(kernel_size=2, stride=2),
                                        nn.LeakyReLU(alpha=0.2))
            # conv block 3: 96 -> 128 channels, with dropout
            # NOTE(review): MindSpore 1.x nn.Dropout takes keep_prob, so 0.6 here
            # means "keep 60%" -- confirm this matches the intended drop rate.
            self.conv_3 = nn.SequentialCell(nn.Conv2d(in_channels=int(self.hidden*1.5),
                                                  out_channels=self.hidden*2,
                                                  kernel_size=3,
                                                  padding=1,
                                                  pad_mode='pad',
                                                  has_bias=False),
                                        nn.BatchNorm2d(num_features=self.hidden * 2),
                                        nn.MaxPool2d(kernel_size=2, stride=2),
                                        nn.LeakyReLU(alpha=0.2),
                                        nn.Dropout(0.6))
            # conv block 4: 128 -> 256 channels, with dropout
            self.conv_4 = nn.SequentialCell(nn.Conv2d(in_channels=self.hidden*2,
                                                  out_channels=self.hidden*4,
                                                  kernel_size=3,
                                                  padding=1,
                                                  pad_mode='pad',
                                                  has_bias=False),
                                        nn.BatchNorm2d(num_features=self.hidden * 4),
                                        nn.MaxPool2d(kernel_size=2, stride=2),
                                        nn.LeakyReLU(alpha=0.2),
                                        nn.Dropout(0.5))
            # final embedding layer (no BatchNorm1d -- the commented-out variant
            # with BN after the Dense is where the reported NaN blow-up could be
            # tamed; unbounded Dense outputs are a known NaN source here)
            self.layer_last = nn.Dense(in_channels=self.last_hidden * 4,
                                       out_channels=self.emb_size, has_bias=True)

        def construct(self, input_data):
            x = self.conv_1(input_data)
            x = self.conv_2(x)
            x = self.conv_3(x)
            x = self.conv_4(x)
            # debug: conv features BEFORE the fully-connected layer.
            # Bug fix: this line previously printed input_data[0] (the raw image)
            # while labelling it "feat", which made it look like the features
            # before the Dense layer were still healthy when the NaN appeared.
            print("feat:", x[0])
            x = self.layer_last(x.view(x.shape[0], -1))
            # debug: embedding AFTER the fully-connected layer
            print("last--------------------------------:", x[0])
            return x
    class NodeUpdateNetwork(nn.Cell):
        """Node-update half of a GNN layer (EGNN-style): aggregates neighbor
        node features through the 2-channel edge matrix, concatenates the
        aggregate with the original node features, and transforms the result
        with a stack of 1x1 convolutions.

        Args:
            in_features (int): dimensionality of the incoming node features.
            num_features (int): base channel width of the transform convs.
            ratio (list[int]): per-layer channel multipliers.
            dropout (float): drop probability applied after the last conv only.
        """
        def __init__(self,
                     in_features,
                     num_features,
                     ratio=[2, 1],
                     dropout=0.0):
            super(NodeUpdateNetwork, self).__init__()
            # set size
            self.in_features = in_features
            # channel widths of the successive 1x1 conv layers
            self.num_features_list = [num_features * r for r in ratio]
            self.dropout = dropout

            # cached functional ops
            self.eye = ops.Eye()
            self.bmm = ops.BatchMatMul()
            self.cat = ops.Concat(-1)
            self.split = ops.Split(1,2)
            self.repeat = ops.Tile()
            self.unsqueeze = ops.ExpandDims()
            self.squeeze = ops.Squeeze()
            self.transpose = ops.Transpose()


            # layers
            layer_list = OrderedDict()
            for l in range(len(self.num_features_list)):

                # first conv sees [node | sim-aggregated | dsim-aggregated],
                # hence in_features * 3 input channels
                layer_list['conv{}'.format(l)] = nn.Conv2d(
                    in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features * 3,
                    out_channels=self.num_features_list[l],
                    kernel_size=1,
                    has_bias=False)
                layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l],)
                layer_list['relu{}'.format(l)] = nn.LeakyReLU(alpha=1e-2)

                # dropout only after the final conv block
                if self.dropout > 0 and l == (len(self.num_features_list) - 1):
                    layer_list['drop{}'.format(l)] = nn.Dropout(keep_prob=1-self.dropout)

            self.network = nn.SequentialCell(layer_list)

        def construct(self, node_feat, edge_feat):
            # node_feat: (num_tasks, num_data, feat_dim) -- assumed; TODO confirm
            # edge_feat: (num_tasks, 2, num_data, num_data) -- assumed; TODO confirm
            # get size
            num_tasks = node_feat.shape[0]
            num_data = node_feat.shape[1]

            # get eye matrix (batch_size x 2 x node_size x node_size);
            # 1 everywhere except the diagonal, used to zero self-edges
            diag_mask = 1.0 - self.repeat(self.unsqueeze(self.unsqueeze(self.eye(num_data,num_data,ms.float32),0),0),(num_tasks,2,1,1))

            # set diagonal as zero and normalize
            # (the original paper uses L1 normalization; the commented-out lines
            # below are that variant, L2Normalize is used here instead)
            # edge_feat = edge_feat * diag_mask
            # edge_feat = edge_feat / ops.clip_by_value(ops.ReduceSum(keep_dims=True)(ops.Abs()(edge_feat), -1),Tensor(0,ms.float32),Tensor(num_data,ms.float32))

            edge_feat = ops.L2Normalize(-1)(edge_feat * diag_mask)

            # compute attention and aggregate: stack the 2 edge channels along
            # the node axis, batch-matmul with node features, then re-split and
            # concatenate onto the original node features
            aggr_feat = self.bmm(self.squeeze(ops.Concat(2)(self.split(edge_feat))),node_feat)
            node_feat = self.cat([node_feat,self.cat(ops.Split(1, 2)(aggr_feat))]).swapaxes(1,2)
            #node_feat = self.transpose(self.cat([node_feat,self.cat(ops.Split(1, 2)(aggr_feat))]),(0,2,1))

            # 1x1-conv transform: features act as conv channels; a trailing
            # singleton dim is added so the 3D tensor fits Conv2d, then removed
            node_feat = self.network(self.unsqueeze(node_feat,(-1))).swapaxes(1,2).squeeze()
            #node_feat = self.squeeze(self.transpose(self.network(self.unsqueeze(node_feat,(-1))),(0,2,1,3)))

            return node_feat
    
    
    class EdgeUpdateNetwork(nn.Cell):
        """Edge-update half of a GNN layer (EGNN-style): computes pairwise
        similarity from node features via a small conv network, combines it
        with the previous edge features, and renormalizes the edge matrix.

        Args:
            in_features (int): dimensionality of the incoming node features.
            num_features (int): base channel width of the similarity convs.
            ratio (list[int]): per-layer channel multipliers.
            separate_dissimilarity (bool): stored but not used in this block --
                dissimilarity is always taken as 1 - similarity below.
            dropout (float): drop probability applied after every conv block.
        """
        def __init__(self,
                     in_features,
                     num_features,
                     ratio=[2, 2, 1, 1],
                     separate_dissimilarity=False,
                     dropout=0.0):
            super(EdgeUpdateNetwork, self).__init__()
            # set size
            self.in_features = in_features
            # channel widths of the successive 1x1 conv layers
            self.num_features_list = [num_features * r for r in ratio]
            self.separate_dissimilarity = separate_dissimilarity
            self.dropout = dropout

            # cached functional ops
            self.eye = ops.Eye()
            self.repeat = ops.Tile()
            self.unsqueeze = ops.ExpandDims()


            # layers
            layer_list = OrderedDict()
            for l in range(len(self.num_features_list)):
                # set layer
                layer_list['conv{}'.format(l)] = nn.Conv2d(in_channels=self.num_features_list[l-1] if l > 0 else self.in_features,
                                                           out_channels=self.num_features_list[l],
                                                           kernel_size=1,
                                                           has_bias=False)
                layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l],
                                                                )
                layer_list['relu{}'.format(l)] = nn.LeakyReLU(alpha=1e-2)

                # dropout after every block (cf. NodeUpdateNetwork: last block only)
                if self.dropout > 0:
                    layer_list['drop{}'.format(l)] = nn.Dropout(keep_prob=1-self.dropout)

            # project down to a single similarity score per node pair
            layer_list['conv_out'] = nn.Conv2d(in_channels=self.num_features_list[-1],
                                               out_channels=1,
                                               kernel_size=1)
            self.sim_network = nn.SequentialCell(layer_list)


        def construct(self, node_feat, edge_feat):
            # node_feat: (num_tasks, num_data, feat_dim) -- assumed; TODO confirm
            # edge_feat: (num_tasks, 2, num_data, num_data) -- assumed; TODO confirm
            # compute abs(x_i - x_j) for every node pair
            x_i = ops.ExpandDims()(node_feat,2)
            x_j = x_i.swapaxes(1,2)
            #x_j = ops.Transpose()(x_i,(0,2,1,3))
            #x_ij = (x_i-x_j)**2
            x_ij = ops.Abs()(x_i-x_j)
            #print("x_ij:",x_ij[0,0,:,:])
            # move feature dim to channel position for the 1x1 conv network
            x_ij = ops.Transpose()(x_ij,(0,3,2,1))
            sim_val = self.sim_network(x_ij)

            # squash the score into (0, 1); dissimilarity is its complement
            sim_val = ops.Sigmoid()(sim_val)
            #print("sim_val", sim_val[0, 0, :, :])

            dsim_val = 1.0 - sim_val

            # zero the diagonal (self-edges) and remember each row's mass so it
            # can be restored after renormalization
            diag_mask = 1.0 - self.repeat(self.unsqueeze(self.unsqueeze(self.eye(node_feat.shape[1],node_feat.shape[1],ms.float32),0),0),(node_feat.shape[0],2,1,1))
            edge_feat = edge_feat * diag_mask
            merge_sum = ops.ReduceSum(keep_dims=True)(edge_feat,-1)
            # set diagonal as zero and normalize
            # (commented-out variant: L1 normalization as in the original paper)
            # edge_feat = ops.Concat(1)([sim_val,dsim_val])*edge_feat
            # edge_feat = edge_feat / ops.clip_by_value((ops.ReduceSum(keep_dims=True)(ops.Abs()(edge_feat), -1)),Tensor(0,ms.float32),Tensor(num_data,ms.float32))
            # edge_feat = edge_feat*merge_sum

            edge_feat = ops.L2Normalize(-1)(ops.Concat(1)([sim_val,dsim_val])*edge_feat)*merge_sum

            # force the similarity channel's diagonal back to 1 (each node is
            # fully similar to itself); dissimilarity channel diagonal stays 0
            force_edge_feat = self.repeat(self.unsqueeze(ops.Concat(0)([self.unsqueeze(self.eye(node_feat.shape[1],node_feat.shape[1],ms.float32),0),self.unsqueeze(ops.Zeros()((node_feat.shape[1],node_feat.shape[1]),ms.float32),0)]),0),(node_feat.shape[0],1,1,1))

            edge_feat = edge_feat + force_edge_feat
            # epsilon guards the division below against zero column sums
            edge_feat = edge_feat + 1e-6
            #print("sum_edge",self.repeat(self.unsqueeze(ops.ReduceSum()(edge_feat,1),1),(1,2,1,1))[0,0])
            # normalize so the two channels sum to 1 for each node pair
            edge_feat = edge_feat / self.repeat(self.unsqueeze(ops.ReduceSum()(edge_feat,1),1),(1,2,1,1))

            return edge_feat
    
    
    class GraphNetwork(nn.Cell):
        """Stack of GNN layers, each alternating a node update (edge -> node)
        with an edge update (node -> edge).

        Args:
            in_features (int): dimensionality of the initial node features.
            node_features (int): node feature width inside the stack.
            edge_features (int): base width of the edge-update convs.
            num_layers (int): number of (node-update, edge-update) pairs.
            dropout (float): dropout for all layers except the last pair.

        Returns (construct):
            list of edge-feature tensors, one per layer.
        """
        def __init__(self,
                     in_features,
                     node_features,
                     edge_features,
                     num_layers,
                     dropout=0.0
                     ):
            super(GraphNetwork, self).__init__()
            # record configuration
            self.in_features = in_features
            self.node_features = node_features
            self.edge_features = edge_features
            self.num_layers = num_layers
            self.dropout = dropout
            self.layers = nn.CellList()
            # build one (edge->node, node->edge) pair per layer
            for layer_idx in range(self.num_layers):
                is_last = layer_idx == self.num_layers - 1
                # no dropout in the final layer pair
                layer_dropout = 0.0 if is_last else self.dropout
                # only the first node update consumes the raw input width
                node_in = self.in_features if layer_idx == 0 else self.node_features
                edge2node_net = NodeUpdateNetwork(in_features=node_in,
                                                  num_features=self.node_features,
                                                  dropout=layer_dropout)
                node2edge_net = EdgeUpdateNetwork(in_features=self.node_features,
                                                  num_features=self.edge_features,
                                                  separate_dissimilarity=False,
                                                  dropout=layer_dropout)
                self.layers.append(nn.CellList([edge2node_net, node2edge_net]))

        # forward
        def construct(self, node_feat, edge_feat):
            edge_feat_list = []
            for layer_idx in range(self.num_layers):
                pair = self.layers[layer_idx]
                # (1) edge to node: refresh node features from current edges
                node_feat = pair[0](node_feat, edge_feat)
                # (2) node to edge: refresh edges from the new node features
                edge_feat = pair[1](node_feat, edge_feat)
                # keep every layer's edge output for deep supervision
                edge_feat_list.append(edge_feat)

            return edge_feat_list
    

    【操作步骤&问题现象】

    我们代码主要功能是用4层卷积加一层全连接层提取图片特征,之后将图片的特征当成图网络每个节点,用GNN。(代码在附件上)

    1、在训练了很多个batch之后,提取出来的特征(经过了4层卷积层和全连接层)出现了很大很大的值,之后几个batch后出现NAN,而在没有经过全连接层的时候,特征数字还是正常的

    2、

    【截图信息】

    这是代码输出的特征

    last--------------------------------: [ 1.918492   -0.8280923   2.0575197   0.3089749  -1.0514854   0.5368729
      0.14135109  1.5270222  -1.4794292  -1.4336827   1.0335447  -0.7093582
     -0.41919574 -0.5667086  -0.3535831   1.5567536   0.5002996  -1.4093596
      0.9674009  -0.18156137  0.14888959  0.6358457   1.406878   -0.03820777
     -0.24577822 -0.25783274  0.5756687  -1.4558431  -1.1002262   0.68062806
     -1.6467474   0.88712454  0.3551372  -1.3449378  -1.7011788  -0.8629771
     -0.92482185  0.9867192  -1.5548937   1.340383   -2.299356   -0.3421743
      1.3239275  -1.3792732  -0.31955895 -0.58364254 -3.7381008  -1.2121737
     -0.75104207 -0.7562581   0.04980466  0.45131734 -1.2448095  -0.33418307
      0.86268485 -1.3601649   1.2753168   2.469506   -1.7358601  -2.9104383
     -0.07392117 -0.73263663  0.11657254 -0.05724781  0.34374043 -0.31884825
      0.13456154  2.3561432  -0.18908082  0.5410311   1.7249999   0.9508886
     -0.30631644  1.6836481   1.1513023  -0.33672807 -0.889638   -0.76715356
     -0.7316199   1.597606   -1.6586273   0.4502733   0.5224928  -3.5851111
     -2.906651   -1.5284328   0.83426046  1.354644   -1.4453334   2.0504599
     -1.3200179  -0.50427496  0.97681373  0.30048305  0.17170379  0.8179815
     -0.92994857  1.333491   -1.2931286  -0.3569969   2.7953048  -3.352736
      1.878619    2.018083   -1.1191074  -1.1341975   1.4532931  -0.66957355
      2.3269157  -0.4198427   0.7148121   0.5458231  -1.3050007  -0.34666243
      2.519589    0.804219    0.91191477  1.3088121   0.6767241   2.1667008
      0.24471135  1.2600335  -1.8683847   2.5641935  -0.9636249  -1.0340385
     -0.32570755 -1.7694132 ]
    ------------------------------------------
    ------------------------------------------------------------------------------- 1 0.7806913
    ---------------------------------------------
    feat: [[[-1.6726604  -1.6897851  -1.7069099  ...  0.43368444  0.46793392
        0.41655967]
      [-1.7069099  -1.7069099  -1.7069099  ...  0.5364329   0.5193082
        0.4850587 ]
      [-1.7240347  -1.7240347  -1.7069099  ...  0.60493195  0.5535577
        0.4850587 ]
      ...
      [-0.6622999  -0.8335474  -0.8677969  ... -0.02868402  0.00556549
       -0.02868402]
      [-0.6622999  -0.69654936 -0.69654936 ... -0.11430778 -0.11430778
       -0.14855729]
      [-0.95342064 -0.8335474  -0.78217316 ... -0.26843056 -0.30268008
       -0.31980482]]

     [[-1.7556022  -1.7731092  -1.7906162  ... -0.617647   -0.582633
       -0.635154  ]
      [-1.7906162  -1.7906162  -1.7906162  ... -0.512605   -0.512605
       -0.565126  ]
      [-1.8081232  -1.8081232  -1.7906162  ... -0.460084   -0.495098
       -0.565126  ]
      ...
      [-0.28501397 -0.37254897 -0.40756297 ... -1.0028011  -0.9677871
       -1.0203081 ]
      [-0.26750696 -0.33753496 -0.32002798 ... -1.12535    -1.1428571
       -1.160364  ]
      [-0.53011197 -0.53011197 -0.44257697 ... -1.2829131  -1.317927
       -1.317927  ]]

     [[-1.68244    -1.6998693  -1.7172985  ... -1.490719   -1.4558606
       -1.490719  ]
      [-1.7172985  -1.7172985  -1.7172985  ... -1.4732897  -1.4384314
       -1.4732897 ]
      [-1.7347276  -1.7347276  -1.7172985  ... -1.4558606  -1.4732897
       -1.5255773 ]
      ...
      [-1.3338562  -1.4210021  -1.4210021  ... -1.6127234  -1.5430065
       -1.5604358 ]
      [-1.2815686  -1.3512855  -1.3338562  ... -1.6127234  -1.5952941
       -1.6127234 ]
      [-1.5081482  -1.4732897  -1.4210021  ... -1.5778649  -1.6127234
       -1.6301525 ]]]
    last--------------------------------: [-9.7715964e+37 -1.3229437e+37 -1.5262715e+38 -2.5811514e+38
      3.2964988e+38 -7.1266450e+37 -7.2963347e+37 -3.0699307e+38
     -1.6108344e+38  5.8011444e+37 -3.9925391e+37 -9.5891957e+37
     -1.7783365e+38  2.2280316e+38 -4.4186918e+37  3.4825655e+37
      5.8457292e+37  7.2160006e+37  1.4259578e+38  9.4037617e+37
      7.4650717e+37  1.8146209e+37 -2.5143476e+38  2.4387442e+38
     -7.5397363e+37  1.4157064e+38 -1.1084308e+38  1.9522180e+38
      2.5864164e+37 -8.5381704e+37  3.3140050e+36 -1.2379668e+38
     -3.3449897e+37  1.6203643e+38  1.4627435e+38  6.6909600e+37
      6.0661751e+37 -1.2335753e+38  1.3377397e+38 -3.7530971e+37
      3.5314601e+37 -1.4393099e+37           -inf -6.0411279e+37
     -7.0721061e+37  1.5951782e+38  9.0163464e+37  1.3680580e+37
     -1.2254094e+37  1.0919689e+38 -1.5229139e+37 -3.4862508e+36
     -8.9739065e+37  2.8713203e+38  9.4768839e+37  7.8658815e+37
     -2.6619306e+38 -7.8224467e+37  6.8780734e+37            inf
     -9.8889302e+37 -1.9009123e+38 -1.4562352e+38 -4.5324568e+37
     -2.6728082e+38  1.0300855e+38 -5.7767852e+37  1.3662499e+37
     -4.0048543e+37 -3.1911765e+37 -1.9702732e+38 -6.5395945e+37
      1.0223747e+38 -2.8775531e+38 -1.1156091e+38 -1.8772822e+38
      1.2472896e+38  1.2465860e+38 -6.7286062e+37 -8.9167649e+37
     -2.8327554e+37 -2.7379526e+37 -1.5994879e+37  1.1577176e+38
      1.1864721e+38  1.7089999e+38 -1.5323652e+37 -1.5374746e+38
      1.2187025e+38 -8.9546139e+37  1.7550813e+38 -5.7048014e+37
     -8.5996788e+37 -5.2310546e+36 -1.4450948e+37 -1.9950120e+37
      4.2429252e+37 -1.4849557e+38  1.0697206e+38 -7.6313524e+37
               -inf  1.7437526e+38 -1.0569269e+38 -1.5577321e+38
     -7.8117285e+37  6.4801082e+37 -3.3032475e+37 -6.4655517e+36
     -2.3770844e+38  1.0880277e+38  3.6430118e+37 -6.9370110e+37
      8.5146681e+37  1.1550550e+38 -2.5614073e+38 -2.1489826e+38
     -8.3233807e+37  2.7233982e+37 -1.3777926e+38 -9.6201629e+37
     -2.1125345e+38 -1.4252791e+36  3.6633845e+37  2.6106833e+37
      9.6643025e+37 -1.4538810e+37 -1.3660478e+38  1.9220696e+38]

    1   采用warmup调整一下学习率,最大学习率设置为0.01;

    2   采用梯度剪裁方法进行保护;

    3   检查最后是否进行了归一化处理:全连接层的输出没有经过归一化,取值范围很可能不在 0-1 之间,建议在 Dense 层后加上 BatchNorm1d 或其他归一化层。

  • 相关阅读:
    操作系统 一个进程通过内核事件 来控制另一个线程的结束
    LeetCode 1123. Lowest Common Ancestor of Deepest Leaves【树,DFS,BFS,哈希表】1607
    微信小程序
    判断链表中是否有环
    将 ONLYOFFICE 文档编辑器与 Node.js 应用集成
    中秋不加班,猿人永不屈服!!! So,How to celebrate the Mid Autumn Festival?
    「Java分享客栈」Nacos配置中心称王称霸,我Apollo一生也不弱于人!
    直击永悦科技半年报:双轮驱动下的“增长曲线”
    智安网络|面临日益增长的安全威胁:云安全和零信任架构的重要性
    Python中的设计模式 -- 工厂模式
  • 原文地址:https://blog.csdn.net/weixin_45666880/article/details/126270389