• 【MindSpore易点通】如何使用溢出检测工具定位精度问题


    一、问题背景

    开发人员在调试模型时可能会遇到一些精度异常情况,如loss无法收敛、精度不达标等,原因可能是计算过程中发生了溢出。为了识别计算过程是否发生了溢出,可以使用溢出检测工具监测网络前反向过程中每个算子的溢出情况。

    溢出检测工具一旦检测到计算过程中有溢出,便会通过[WARNING]消息告知用户具体哪个算子计算时发生了溢出,并保存该算子对应的输入和输出数据(通过Dump方式保存),供用户分析。

    注:当前溢出检测工具仅支持检测上溢出,暂不支持检测下溢出。

    **上溢出:**数值超出最大表示范围。如表示区间为0~65504,那么当进行50000 + 50000的操作时,则会发生上溢出,输出结果可能是65500。

    **下溢出:**数值小于最小表示精度。如最小表示精度为6e-8,那么当进行8e-8 / 10的操作时,则会发生下溢出,输出结果可能是0。

    二、溢出检测工具使用流程

    2.1、示例代码

    以下代码对FP16格式的数据进行40000+41000的求和操作,理论结果为81000,但超过了FP16的最大表示值65504,因此求和计算过程中将发生上溢出。

    下面将以该代码为示例,展示溢出检测工具的使用流程。

    1. import mindspore.nn as nn
    2. import numpy as npfrom mindspore
    3. import Tensorfrom mindspore
    4. import contextimport mindspore.ops as ops
    5. context.set_context(reserve_class_name_in_scope=False)
    6. class PrintDemo(nn.Cell):
    7.         def __init__(self):
    8.             super(PrintDemo, self).__init__()
    9.         def construct(self, input_pra):
    10.             out = input_pra.sum()
    11.             return out
    12. def test():
    13.     input_x = Tensor(np.array([40000, 0, 41000]).astype(np.float16))
    14.     input_pra = Tensor(input_x)
    15.     net = PrintDemo()
    16.     net(input_pra)
    17.     return net(input_pra)
    18. print(test())

    2.2、创建配置文件data_dump.json

    溢出检测是异步Dump下的一个子功能;通过修改data_dump.json文件中op_debug_mode 的值开启溢出检测工具。

    1. {
    2.     "common_dump_settings": {
    3.         "dump_mode": 0,
    4.         "path": "/absolute_path/",
    5.         "net_name": "ResNet50",
    6.         "iteration": "0|5-8|100-120",
    7.         "input_output": 0,
    8.         "kernels": ["Default/Conv-op12"],
    9.         "support_device": [0,1,2,3,4,5,6,7],
    10.         "op_debug_mode": 3
    11.     }
    12. }

    • dump_mode:设置成0即可,表示Dump出该网络中的所有算子数据。
    • path:Dump保存数据的绝对路径。此流程执行完毕后,此网络中若有算子溢出则会在路径path下生成该算子的Dump数据。
    • net_name:自定义的网络名称,随便填,例如:”ResNet50”。
    • iteration:指定需要进行溢出检测的迭代。类型为str,用“|”分离要保存的不同区间的step的数据。如”0|5-8|100-120”表示Dump第1个,第6个到第9个, 第101个到第121个step的数据。指定“all”,表示Dump所有迭代的数据。这里推荐设置成“all”。
    • input_output:设置成0,表示Dump出算子的输入和算子的输出;设置成1,表示Dump出算子的输入;设置成2,表示Dump出算子的输出。
    • kernels:算子的名称列表。使用溢出检测工具时不用特别关注。
    • support_device:设置成[0,1,2,3,4,5,6,7]即可,表示开启溢出检测的device id。
    • op_debug_mode:设置成3即可,表示开启全部溢出检测功能。
    • op_debug_mode:该属性用于算子溢出调试,设置成0,表示不开启溢出;设置成1,表示开启AiCore溢出检测;设置成2,表示开启Atomic溢出检测;设置成3,表示开启全部溢出检测功能。在Dump数据的时候请设置成0,若设置成其他值,则只会Dump溢出算子的数据。

    2.3、设置环境变量

    export MINDSPORE_DUMP_CONFIG={Absolute path of data_dump.json}

    在网络脚本执行前,设置好环境变量;否则溢出检测工具将不生效。

    在分布式场景下,该环境变量需要在调用mindspore.communication.init之前配置。

    2.4、执行训练

    正常启动训练即可。

    **注:**可以在训练脚本中设置context.set_context(reserve_class_name_in_scope=False),避免Dump文件名称过长导致Dump数据文件生成失败。

    2.5、若发生溢出,告知用户并自动保存溢出算子数据

    若算子有溢出,则溢出检测工具会将算子溢出信息通过打屏日志[WARNING]显示,溢出检测工具执行流程以及溢出信息展示如下图中红框所示:

    2.6、解析保存的溢出算子数据

    步骤一:找到在json文件中所配置的"path": "/absolute_path/"路径下所生成rank_x文件夹下算子数据文件。使用run包中提供的msaccucmp.py解析Dump出来的文件。不同的环境上msaccucmp.py文件所在的路径可能不同,可以通过find命令进行查找:

    find ${run_path} -name "msaccucmp.py"

    run_path:run包的安装路径,如/usr/local/Ascend

    步骤二:找到msaccucmp.py后,到/absolute_path目录下,运行如下命令解析Dump数据:

    python ${The absolute path of msaccucmp.py} convert -d {file path of dump} -out {file path of output}

    解析完成后,会生成input与output文件,具体格式如以下所示:

    1. ReduceSum.Default_ReduceSum-op0.2.7.1641535332858320.input.0.npy
    2. ReduceSum.Default_ReduceSum-op0.2.7.1641535332858320.output.0.npy

    注:input.0.npy文件为ReduceSum算子的输入,即代码中的input_x;output.0.npy文件为所存储数据为算子运行后计算的结果。

    步骤三:解析npy文件分析溢出情况。通过numpy.load("file_name")工具加载out文件可以读取到对应数据。

    1、使用加载工具解析input.0.npy文件。

    1. import numpy
    2. numpy.load("ReduceSum.Default_ReduceSum-op0.2.7.1641535332858320.input.0.npy")

    根据解析文件可以看到该文件所存储的数据为input.0.npy文件为ReduceSum算子的输入,即代码中的input_x;

    2、使用加载工具解析output.0.npy文件,解析方法以及代码如上所示,故在此不再赘述;

    文件解析后的结果为[65500.],即算子溢出后的输出。

  • 相关阅读:
    大数据开发(HBase面试真题-卷二)
    jsencrypt.js加密java后端解密
    linux系统Jenkins工具的node节点配置
    云数据库技术行业动态@2022-09-30
    〖Python 数据库开发实战 - MySQL篇⑮〗- 数据表结果集的排序与去除重复(去重)
    使用 Python 连接到 PostgreSQL 数据库
    根据json生成Java类
    Mysql8.1.0 windows 绿色版安装
    Vue是什么?
    Java诊断利器Arthas安装和使用
  • 原文地址:https://blog.csdn.net/Kenji_Shinji/article/details/127617527