【Functionality】The complete code is in the attachment; the dataset can also be provided if needed.
from collections import OrderedDict

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops


class EmbeddingImagenet(nn.Cell):
    def __init__(self, emb_size, cifar_flag=False):
        super(EmbeddingImagenet, self).__init__()
        # set size
        self.hidden = 64
        self.last_hidden = self.hidden * 25 if not cifar_flag else self.hidden * 4
        self.emb_size = emb_size
        self.out_dim = emb_size
        # set layers
        self.conv_1 = nn.SequentialCell(
            nn.Conv2d(in_channels=3, out_channels=self.hidden, kernel_size=3,
                      padding=1, pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=self.hidden),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.LeakyReLU(alpha=0.2))
        self.conv_2 = nn.SequentialCell(
            nn.Conv2d(in_channels=self.hidden, out_channels=int(self.hidden * 1.5),
                      kernel_size=3, padding=1, pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=int(self.hidden * 1.5)),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.LeakyReLU(alpha=0.2))
        self.conv_3 = nn.SequentialCell(
            nn.Conv2d(in_channels=int(self.hidden * 1.5), out_channels=self.hidden * 2,
                      kernel_size=3, padding=1, pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=self.hidden * 2),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.LeakyReLU(alpha=0.2),
            nn.Dropout(0.6))
        self.conv_4 = nn.SequentialCell(
            nn.Conv2d(in_channels=self.hidden * 2, out_channels=self.hidden * 4,
                      kernel_size=3, padding=1, pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=self.hidden * 4),  # 16 * 64 * (5 * 5)
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.LeakyReLU(alpha=0.2),
            nn.Dropout(0.5))
        # self.layer_last = nn.SequentialCell(nn.Dense(in_channels=self.last_hidden * 4,
        #                                              out_channels=self.emb_size, has_bias=True),
        #                                     nn.BatchNorm1d(self.emb_size))
        self.layer_last = nn.Dense(in_channels=self.last_hidden * 4,
                                   out_channels=self.emb_size, has_bias=True)
        # self.bn = nn.BatchNorm1d(self.emb_size)

    def construct(self, input_data):
        x = self.conv_1(input_data)
        x = self.conv_2(x)
        x = self.conv_3(x)
        x = self.conv_4(x)
        # debug: despite the "feat:" label, this prints the raw input image, not the conv features
        print("feat:", input_data[0])
        # flatten the conv output and project to the embedding size
        x = self.layer_last(x.view(x.shape[0], -1))
        # debug: final embedding -- this is where the huge values / NaN first appear
        print("last--------------------------------:", x[0])
        return x
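For reference, a minimal smoke test of the backbone. The 84x84 miniImageNet-style input size, the batch size, and emb_size=128 are assumptions for illustration, not values from the attachment; with cifar_flag=False the flattened conv output is last_hidden * 4 = 64 * 25 * 4 = 6400.

# Smoke-test sketch (assumed input size): push a dummy batch through and
# verify the embedding stays finite -- useful for reproducing the NaN issue.
import numpy as np
from mindspore import Tensor

net = EmbeddingImagenet(emb_size=128)
dummy = Tensor(np.random.randn(4, 3, 84, 84).astype(np.float32))
emb = net(dummy)
print(emb.shape)                             # expected: (4, 128)
print(np.isfinite(emb.asnumpy()).all())      # should remain True throughout training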
class NodeUpdateNetwork(nn.Cell):
    def __init__(self, in_features, num_features, ratio=[2, 1], dropout=0.0):
        super(NodeUpdateNetwork, self).__init__()
        # set size
        self.in_features = in_features
        self.num_features_list = [num_features * r for r in ratio]
        self.dropout = dropout
        self.eye = ops.Eye()
        self.bmm = ops.BatchMatMul()
        self.cat = ops.Concat(-1)
        self.split = ops.Split(1, 2)
        self.repeat = ops.Tile()
        self.unsqueeze = ops.ExpandDims()
        self.squeeze = ops.Squeeze()
        self.transpose = ops.Transpose()
        # layers: a stack of 1x1 conv -> BN -> LeakyReLU blocks
        layer_list = OrderedDict()
        for l in range(len(self.num_features_list)):
            layer_list['conv{}'.format(l)] = nn.Conv2d(
                in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features * 3,
                out_channels=self.num_features_list[l],
                kernel_size=1,
                has_bias=False)
            layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l])
            layer_list['relu{}'.format(l)] = nn.LeakyReLU(alpha=1e-2)
            if self.dropout > 0 and l == (len(self.num_features_list) - 1):
                layer_list['drop{}'.format(l)] = nn.Dropout(keep_prob=1 - self.dropout)
        self.network = nn.SequentialCell(layer_list)

    def construct(self, node_feat, edge_feat):
        # get size
        num_tasks = node_feat.shape[0]
        num_data = node_feat.shape[1]
        # get eye matrix (batch_size x 2 x node_size x node_size)
        diag_mask = 1.0 - self.repeat(
            self.unsqueeze(self.unsqueeze(self.eye(num_data, num_data, ms.float32), 0), 0),
            (num_tasks, 2, 1, 1))
        # set the diagonal to zero and normalize each row
        # (the original paper uses L1 normalization:
        #  edge_feat = edge_feat * diag_mask
        #  edge_feat = edge_feat / ops.clip_by_value(ops.ReduceSum(keep_dims=True)(ops.Abs()(edge_feat), -1),
        #                                            Tensor(0, ms.float32), Tensor(num_data, ms.float32)))
        edge_feat = ops.L2Normalize(-1)(edge_feat * diag_mask)
        # compute attention and aggregate the neighbors' features
        aggr_feat = self.bmm(self.squeeze(ops.Concat(2)(self.split(edge_feat))), node_feat)
        node_feat = self.cat([node_feat, self.cat(ops.Split(1, 2)(aggr_feat))]).swapaxes(1, 2)
        node_feat = self.network(self.unsqueeze(node_feat, -1)).swapaxes(1, 2).squeeze()
        return node_feat
class EdgeUpdateNetwork(nn.Cell):
    def __init__(self, in_features, num_features, ratio=[2, 2, 1, 1],
                 separate_dissimilarity=False, dropout=0.0):
        super(EdgeUpdateNetwork, self).__init__()
        # set size
        self.in_features = in_features
        self.num_features_list = [num_features * r for r in ratio]
        self.separate_dissimilarity = separate_dissimilarity
        self.dropout = dropout
        self.eye = ops.Eye()
        self.repeat = ops.Tile()
        self.unsqueeze = ops.ExpandDims()
        # layers
        layer_list = OrderedDict()
        for l in range(len(self.num_features_list)):
            # set layer
            layer_list['conv{}'.format(l)] = nn.Conv2d(
                in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features,
                out_channels=self.num_features_list[l],
                kernel_size=1,
                has_bias=False)
            layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l])
            layer_list['relu{}'.format(l)] = nn.LeakyReLU(alpha=1e-2)
            if self.dropout > 0:
                layer_list['drop{}'.format(l)] = nn.Dropout(keep_prob=1 - self.dropout)
        layer_list['conv_out'] = nn.Conv2d(in_channels=self.num_features_list[-1],
                                           out_channels=1,
                                           kernel_size=1)
        self.sim_network = nn.SequentialCell(layer_list)

    def construct(self, node_feat, edge_feat):
        # compute abs(x_i - x_j) for every node pair
        x_i = ops.ExpandDims()(node_feat, 2)
        x_j = x_i.swapaxes(1, 2)
        x_ij = ops.Abs()(x_i - x_j)
        x_ij = ops.Transpose()(x_ij, (0, 3, 2, 1))
        # pairwise similarity in (0, 1); dissimilarity is its complement
        sim_val = ops.Sigmoid()(self.sim_network(x_ij))
        dsim_val = 1.0 - sim_val
        diag_mask = 1.0 - self.repeat(
            self.unsqueeze(self.unsqueeze(self.eye(node_feat.shape[1], node_feat.shape[1], ms.float32), 0), 0),
            (node_feat.shape[0], 2, 1, 1))
        edge_feat = edge_feat * diag_mask
        merge_sum = ops.ReduceSum(keep_dims=True)(edge_feat, -1)
        # set the diagonal to zero and normalize (the original uses L1; L2 is used here)
        edge_feat = ops.L2Normalize(-1)(ops.Concat(1)([sim_val, dsim_val]) * edge_feat) * merge_sum
        # force self-loops: identity on the similarity channel, zeros on the dissimilarity channel
        force_edge_feat = self.repeat(
            self.unsqueeze(ops.Concat(0)([
                self.unsqueeze(self.eye(node_feat.shape[1], node_feat.shape[1], ms.float32), 0),
                self.unsqueeze(ops.Zeros()((node_feat.shape[1], node_feat.shape[1]), ms.float32), 0)]), 0),
            (node_feat.shape[0], 1, 1, 1))
        edge_feat = edge_feat + force_edge_feat
        edge_feat = edge_feat + 1e-6
        # renormalize so the two channels sum to one for every edge
        edge_feat = edge_feat / self.repeat(self.unsqueeze(ops.ReduceSum()(edge_feat, 1), 1), (1, 2, 1, 1))
        return edge_feat
class GraphNetwork(nn.Cell):
    def __init__(self, in_features, node_features, edge_features, num_layers, dropout=0.0):
        super(GraphNetwork, self).__init__()
        # set size
        self.in_features = in_features
        self.node_features = node_features
        self.edge_features = edge_features
        self.num_layers = num_layers
        self.dropout = dropout
        self.layers = nn.CellList()
        # for each layer
        for l in range(self.num_layers):
            # set edge to node
            edge2node_net = NodeUpdateNetwork(
                in_features=self.in_features if l == 0 else self.node_features,
                num_features=self.node_features,
                dropout=self.dropout if l < self.num_layers - 1 else 0.0)
            # set node to edge
            node2edge_net = EdgeUpdateNetwork(
                in_features=self.node_features,
                num_features=self.edge_features,
                separate_dissimilarity=False,
                dropout=self.dropout if l < self.num_layers - 1 else 0.0)
            self.layers.append(nn.CellList([edge2node_net, node2edge_net]))

    # forward
    def construct(self, node_feat, edge_feat):
        # for each layer
        edge_feat_list = []
        for l in range(self.num_layers):
            # (1) edge to node
            node_feat = self.layers[l][0](node_feat, edge_feat)
            # (2) node to edge
            edge_feat = self.layers[l][1](node_feat, edge_feat)
            # save edge feature
            edge_feat_list.append(edge_feat)
        return edge_feat_list
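For orientation, a shape-check sketch of the graph module. The episode sizes (num_tasks, num_data), the feature widths, and the uniform 0.5 edge initialization below are assumptions for illustration only:

# Shape-check sketch (assumed sizes): node_feat is (tasks, nodes, emb),
# edge_feat is (tasks, 2, nodes, nodes) with similarity/dissimilarity channels.
import numpy as np
from mindspore import Tensor

num_tasks, num_data, emb_size = 8, 6, 128
gnn = GraphNetwork(in_features=emb_size, node_features=96,
                   edge_features=96, num_layers=3, dropout=0.1)
node_feat = Tensor(np.random.randn(num_tasks, num_data, emb_size).astype(np.float32))
edge_feat = Tensor(np.full((num_tasks, 2, num_data, num_data), 0.5, np.float32))
edge_list = gnn(node_feat, edge_feat)
print(len(edge_list), edge_list[0].shape)    # 3, (8, 2, 6, 6)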
【Steps & Problem】
The main function of our code: extract image features with 4 convolution layers plus one fully connected layer, then treat each image's feature vector as a node of a graph and apply a GNN. (The code is in the attachment.)
1. After training for many batches, the extracted features (after the 4 convolution layers and the fully connected layer) contain extremely large values, and a few batches later become NaN, while the features that have not passed through the fully connected layer still look normal (a debugging sketch for catching this follows below).
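To narrow down where the overflow first appears, a PyNative-mode debugging sketch can help. The `report` helper is introduced here for illustration and is not part of the attached code; the 1e4 threshold is arbitrary:

# Debugging sketch: call report() on the backbone output (and on intermediate
# tensors) after each training step to catch the batch where values explode.
import numpy as np

def report(name, tensor):
    arr = tensor.asnumpy()
    finite = np.isfinite(arr)
    if not finite.all():
        print(name, "contains inf/nan at", np.count_nonzero(~finite), "positions")
    elif np.abs(arr).max() > 1e4:    # arbitrary early-warning threshold
        print(name, "is growing large; max |value|:", np.abs(arr).max())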
【Screenshot Info】
This is the feature output printed by the code:
last--------------------------------: [ 1.918492 -0.8280923 2.0575197 0.3089749 -1.0514854 0.5368729
0.14135109 1.5270222 -1.4794292 -1.4336827 1.0335447 -0.7093582
-0.41919574 -0.5667086 -0.3535831 1.5567536 0.5002996 -1.4093596
0.9674009 -0.18156137 0.14888959 0.6358457 1.406878 -0.03820777
-0.24577822 -0.25783274 0.5756687 -1.4558431 -1.1002262 0.68062806
-1.6467474 0.88712454 0.3551372 -1.3449378 -1.7011788 -0.8629771
-0.92482185 0.9867192 -1.5548937 1.340383 -2.299356 -0.3421743
1.3239275 -1.3792732 -0.31955895 -0.58364254 -3.7381008 -1.2121737
-0.75104207 -0.7562581 0.04980466 0.45131734 -1.2448095 -0.33418307
0.86268485 -1.3601649 1.2753168 2.469506 -1.7358601 -2.9104383
-0.07392117 -0.73263663 0.11657254 -0.05724781 0.34374043 -0.31884825
0.13456154 2.3561432 -0.18908082 0.5410311 1.7249999 0.9508886
-0.30631644 1.6836481 1.1513023 -0.33672807 -0.889638 -0.76715356
-0.7316199 1.597606 -1.6586273 0.4502733 0.5224928 -3.5851111
-2.906651 -1.5284328 0.83426046 1.354644 -1.4453334 2.0504599
-1.3200179 -0.50427496 0.97681373 0.30048305 0.17170379 0.8179815
-0.92994857 1.333491 -1.2931286 -0.3569969 2.7953048 -3.352736
1.878619 2.018083 -1.1191074 -1.1341975 1.4532931 -0.66957355
2.3269157 -0.4198427 0.7148121 0.5458231 -1.3050007 -0.34666243
2.519589 0.804219 0.91191477 1.3088121 0.6767241 2.1667008
0.24471135 1.2600335 -1.8683847 2.5641935 -0.9636249 -1.0340385
-0.32570755 -1.7694132 ]
------------------------------------------
------------------------------------------------------------------------------- 1 0.7806913
---------------------------------------------
feat: [[[-1.6726604 -1.6897851 -1.7069099 ... 0.43368444 0.46793392
0.41655967]
[-1.7069099 -1.7069099 -1.7069099 ... 0.5364329 0.5193082
0.4850587 ]
[-1.7240347 -1.7240347 -1.7069099 ... 0.60493195 0.5535577
0.4850587 ]
...
[-0.6622999 -0.8335474 -0.8677969 ... -0.02868402 0.00556549
-0.02868402]
[-0.6622999 -0.69654936 -0.69654936 ... -0.11430778 -0.11430778
-0.14855729]
[-0.95342064 -0.8335474 -0.78217316 ... -0.26843056 -0.30268008
-0.31980482]]
[[-1.7556022 -1.7731092 -1.7906162 ... -0.617647 -0.582633
-0.635154 ]
[-1.7906162 -1.7906162 -1.7906162 ... -0.512605 -0.512605
-0.565126 ]
[-1.8081232 -1.8081232 -1.7906162 ... -0.460084 -0.495098
-0.565126 ]
...
[-0.28501397 -0.37254897 -0.40756297 ... -1.0028011 -0.9677871
-1.0203081 ]
[-0.26750696 -0.33753496 -0.32002798 ... -1.12535 -1.1428571
-1.160364 ]
[-0.53011197 -0.53011197 -0.44257697 ... -1.2829131 -1.317927
-1.317927 ]]
[[-1.68244 -1.6998693 -1.7172985 ... -1.490719 -1.4558606
-1.490719 ]
[-1.7172985 -1.7172985 -1.7172985 ... -1.4732897 -1.4384314
-1.4732897 ]
[-1.7347276 -1.7347276 -1.7172985 ... -1.4558606 -1.4732897
-1.5255773 ]
...
[-1.3338562 -1.4210021 -1.4210021 ... -1.6127234 -1.5430065
-1.5604358 ]
[-1.2815686 -1.3512855 -1.3338562 ... -1.6127234 -1.5952941
-1.6127234 ]
[-1.5081482 -1.4732897 -1.4210021 ... -1.5778649 -1.6127234
-1.6301525 ]]]
last--------------------------------: [-9.7715964e+37 -1.3229437e+37 -1.5262715e+38 -2.5811514e+38
3.2964988e+38 -7.1266450e+37 -7.2963347e+37 -3.0699307e+38
-1.6108344e+38 5.8011444e+37 -3.9925391e+37 -9.5891957e+37
-1.7783365e+38 2.2280316e+38 -4.4186918e+37 3.4825655e+37
5.8457292e+37 7.2160006e+37 1.4259578e+38 9.4037617e+37
7.4650717e+37 1.8146209e+37 -2.5143476e+38 2.4387442e+38
-7.5397363e+37 1.4157064e+38 -1.1084308e+38 1.9522180e+38
2.5864164e+37 -8.5381704e+37 3.3140050e+36 -1.2379668e+38
-3.3449897e+37 1.6203643e+38 1.4627435e+38 6.6909600e+37
6.0661751e+37 -1.2335753e+38 1.3377397e+38 -3.7530971e+37
3.5314601e+37 -1.4393099e+37 -inf -6.0411279e+37
-7.0721061e+37 1.5951782e+38 9.0163464e+37 1.3680580e+37
-1.2254094e+37 1.0919689e+38 -1.5229139e+37 -3.4862508e+36
-8.9739065e+37 2.8713203e+38 9.4768839e+37 7.8658815e+37
-2.6619306e+38 -7.8224467e+37 6.8780734e+37 inf
-9.8889302e+37 -1.9009123e+38 -1.4562352e+38 -4.5324568e+37
-2.6728082e+38 1.0300855e+38 -5.7767852e+37 1.3662499e+37
-4.0048543e+37 -3.1911765e+37 -1.9702732e+38 -6.5395945e+37
1.0223747e+38 -2.8775531e+38 -1.1156091e+38 -1.8772822e+38
1.2472896e+38 1.2465860e+38 -6.7286062e+37 -8.9167649e+37
-2.8327554e+37 -2.7379526e+37 -1.5994879e+37 1.1577176e+38
1.1864721e+38 1.7089999e+38 -1.5323652e+37 -1.5374746e+38
1.2187025e+38 -8.9546139e+37 1.7550813e+38 -5.7048014e+37
-8.5996788e+37 -5.2310546e+36 -1.4450948e+37 -1.9950120e+37
4.2429252e+37 -1.4849557e+38 1.0697206e+38 -7.6313524e+37
-inf 1.7437526e+38 -1.0569269e+38 -1.5577321e+38
-7.8117285e+37 6.4801082e+37 -3.3032475e+37 -6.4655517e+36
-2.3770844e+38 1.0880277e+38 3.6430118e+37 -6.9370110e+37
8.5146681e+37 1.1550550e+38 -2.5614073e+38 -2.1489826e+38
-8.3233807e+37 2.7233982e+37 -1.3777926e+38 -9.6201629e+37
-2.1125345e+38 -1.4252791e+36 3.6633845e+37 2.6106833e+37
9.6643025e+37 -1.4538810e+37 -1.3660478e+38 1.9220696e+38]
1. Use warmup to adjust the learning rate, with the maximum learning rate set to 0.01 (see the warmup sketch below);
2. Apply gradient clipping as a safeguard (see the clipping sketch below);
3. Check whether the output is normalized at the end; its value range is probably not within 0-1 (see the BatchNorm1d sketch below).
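For point 1, a minimal warmup sketch: MindSpore optimizers accept a per-step learning-rate list, so a linear ramp to the suggested 0.01 can be built directly. The step counts and the choice of Adam are assumptions, and `net` stands for the full model:

# Warmup sketch: ramp the lr linearly from ~0 to 0.01 over warmup_steps,
# then hold. warmup_steps/total_steps are assumed values -- tune to your run.
max_lr, warmup_steps, total_steps = 0.01, 500, 20000
lr_list = [max_lr * min(1.0, (i + 1) / warmup_steps) for i in range(total_steps)]
optimizer = nn.Adam(net.trainable_params(), learning_rate=lr_list)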
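For point 2, one common MindSpore pattern is a TrainOneStepCell subclass that clips gradients by global norm before the optimizer update. This is a sketch, not the attachment's training wrapper, and the default clip_norm of 1.0 is an assumed value:

# Gradient-clipping sketch: compute gradients as TrainOneStepCell does,
# clip them by global norm, then apply the optimizer update.
class TrainOneStepWithClip(nn.TrainOneStepCell):
    """TrainOneStepCell variant that clips gradients by global norm."""
    def __init__(self, network, optimizer, clip_norm=1.0, sens=1.0):
        super(TrainOneStepWithClip, self).__init__(network, optimizer, sens)
        self.clip_norm = clip_norm

    def construct(self, *inputs):
        loss = self.network(*inputs)
        sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        grads = self.grad_reducer(grads)                         # identity on a single device
        grads = ops.clip_by_global_norm(grads, self.clip_norm)   # cap the update size
        return ops.depend(loss, self.optimizer(grads))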
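For point 3, the commented-out variant of layer_last in EmbeddingImagenet already contains one candidate fix: a BatchNorm1d after the Dense layer keeps the embedding in a bounded range instead of letting it drift. Re-enabling it in __init__ would look like:

# Candidate fix: restore the commented-out BatchNorm1d after the projection
# so the embedding is normalized before it enters the GNN.
self.layer_last = nn.SequentialCell(
    nn.Dense(in_channels=self.last_hidden * 4,
             out_channels=self.emb_size, has_bias=True),
    nn.BatchNorm1d(self.emb_size))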