目前在做实验,参考了一个新的网络架构之后发现训练时损失出现Nan,参数了出现了inf的情况,先说说我的排查经历。
首先肯定是打印损失,损失是最容易出现Nan的,有各种原因,网上也有很多解决办法,我这里就不一一赘述了,大伙打开CSDN就一搜就有很多很全的
我的问题是在训练的中间参数中出现了inf,导致最终的损失为NaN或者inf
用下面的代码判断参数是否出现了NaN或者inf
- for i in range(5): # exam是一个参数列表
- if torch.isnan(exeam[i]).any(): print('下表为{}的元素存在NaN!'.format(i))
- if torch.isinf(exeam[i]).any(): print('下表为{}的元素存在inf!'.format(i))
确定是哪些为NaN之后,直接上Relu或者归一化,很可惜,没用。。。
我参考的文章是这两篇
- # feat是网络输出的结果,10通道
- K, atp, tran, B = torch.split(feat, (1, 3, 3, 3), dim=1)
-
- # x是网络的输入
- atp = K * atp - atp + x
- tran = K * tran - tran + x
- x = K * x - B + x
- # H, W是限定的尺寸
- rgb = x[:, :, :H, :W]
- atp = atp[:, :, :H, :W]
- tran = tran[:, :, :H, :W]
这样一弄,问题就解决了
我分析了下原因,代码中的参数出现NaN是因为出现了除以0的情况,加上了soft reconstruction之后(类似于全局残差,关键是后面加上x的那个操作)是原先为0的参数变得不为0了,除以0的情况消失了,就不存在NaN啦。
我是做视觉方向的,全局残差机制(ResNet,FFA-Net)在视觉中可谓是有百利而无一害,所以这样加应该没问题的
疑问