• 记录PyTorch中半精度amp训练出现Nan的排查过程


    网络的教程来看,在半精度amp训练出现nan问题,无非就是这几种:

    • 计算loss 时,出现了除以0的情况
    • loss过大,被半精度判断为inf
    • 网络参数中有nan,那么运算结果也会输出nan(这个更像是现象而不是原因,网络中出现nan肯定是之前出现了nan或inf)

    但是总结起来就三种:

    • 运算错误,比如计算Loss时出现x/0造成错误
    • 数值溢出,运算结果超出了表示范围,比如权重和输入正常,但是运算结果Nan或Inf。比如loss过大其实就是超出表示范围变成inf
    • 梯度问题,可能梯度回传出现问题(不了解)

    0、结论

    先说结论,我使用amp半精度训练,即中间会参杂float16数据类型,加快训练过程。

    但是本文出现Nan就是因为float16,因为float16支持的最大值在65504,而我的模型中涉及一个矩阵乘法(其实就是transformer中的q@k运算)。其中,a∈[-38,40],b∈[-39,40],而矩阵乘法a@b=c,c∈[-61408,inf]。因为a和b的矩阵乘法运算后最大值超过了float16最大表示,造成出现inf,所以最终结果出现Nan。

    1、粗定位

    一个训练的过程可以表示为以下流程:
    在这里插入图片描述

    1.1 定位到epoch

    首先看到,在epoch4的输出loss是正常的,意味着在epoch4中的0~498iter的训练过程中正常,那么问题就可能出现在epoch4第499iter和epoch0~499iter这501个iter之中。

    1.2 定位到iter

    现在我们需要定位到具体iter。

    可以根据二分法进行判断,debug模型,在epoch=5轮次中的100iter、300iter、499iter分别查看loss是否正常,依此类推定位到具体的iter。

    我的是在epoch=5的iter161~162之间,iter=161时loss正常,iter=162时loss为Nan。iter还是遵从上图的流程,可以看到问题无非出现在iter=161时的梯度计算和权重更新,以及iter=162的前向运算和损失计算,这4处。

    1.3 定位到具体步骤

    在debug时直接暂停到epoch=5和iter=162的前向运算之前。

    首先来看权重是否正常:

    # 在iter=162的模型推理之前,检查权重是否存在异常值,比如Nan或inf
    if epoch == 5:
        if i == 162:
            print(epoch, i)
    
            class bcolors:
                HEADER = '\033[95m'
                OKBLUE = '\033[94m'
                OKGREEN = '\033[92m'
                WARNING = '\033[93m'
                FAIL = '\033[91m'
                ENDC = '\033[0m'
                BOLD = '\033[1m'
                UNDERLINE = '\033[4m'
    
            # print grad check
            v_n = []
            v_v = []
            v_g = []
            for name, parameter in model.named_parameters():
                v_n.append(name)
                v_v.append(parameter.detach().cpu().numpy() if parameter is not None else [0])
                v_g.append(parameter.grad.detach().cpu().numpy() if parameter.grad is not None else [0])
            for j in range(len(v_n)):
                if np.isnan(np.max(v_v[j]).item() - np.min(v_v[j]).item()) or np.isnan(
                        np.max(v_g[j]).item() - np.min(v_g[j]).item()):
                    color = bcolors.FAIL + '*'
                else:
                    color = bcolors.OKGREEN + ' '
                print('%svalue %s: %.3e ~ %.3e' % (color, v_n[j], np.min(v_v[j]).item(), np.max(v_v[j]).item()))
                print('%sgrad  %s: %.3e ~ %.3e' % (color, v_n[j], np.min(v_g[j]).item(), np.max(v_g[j]).item()))
    
    outputs = model(images)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33

    通过检查,证明权重没有问题,所以问题被限定iter=162的前向推理和损失计算两处

    检查输入和输出
    通过代码:

    print(images.mean())	# 检查输入,正常
    outputs = model(images)
    print(outputs .mean())	# 检查输出,Nan
    
    • 1
    • 2
    • 3

    由此我们知道了情况:模型权重正常,模型输入正常,但是模型的输出Nan

    2、精确定位

    到这里就好办了,借助pycharm,我们一步一步调试在模型中各个模型的输入输出,看看到底是在模型的哪一个部分出现了Nan或者Inf,最终定位到一行代码:

    attn = (q @ k.transpose(-2, -1)) * self.scale
    
    • 1

    这句代码是想实现q和k的矩阵乘法, 他们的值域分别为:

    张量max(约值)min(约值)
    q38-37
    k40-38
    attninf-61408

    从这里可以发现,就是单纯的计算问题,一种很常见的就是数值溢出,考虑到我使用半精度float16,通过查询其最大值是65504,所以很有可能是最大值溢出了。为了验证,我们可以在计算前,将q和k转为double(float64),可以发现其计算结果正常了,类型也是float64。这表明就是因为数值溢出造成的。
    在这里插入图片描述

    3、解决办法

    现在已知我的原因是数值溢出,一种方法是截取:将inf或nan设置为一个常量,我则在运算前将q和k进行norm归一化到[-1,1],这样保证了运算结果不会太大(没有什么原因,就是无脑操作,不建议学习)。

  • 相关阅读:
    时序数据库 TimescaleDB 基础概念
    奇安信java面试
    【5G NR】3GPP常用协议整理
    基于Tree-LSTM网络语义表示模型
    2022 年最值得关注的颠覆性技术
    索尼 toio™ 应用创意开发征文|探索创新的玩乐世界——索尼 toio™
    Linux/CentOS 安装 flutter 与 jenkins 构建 (踩坑)
    (还在纠结Notability、Goodnotes和Marginnote吗?)iPad、安卓平板、Windows学习软件推荐
    SSM大学生兼职管理系统
    一个支持IPFS的电子邮件——SKIFF
  • 原文地址:https://blog.csdn.net/qq_40243750/article/details/128207067