• 一文理解深度学习框架中的InstanceNorm


    91416067f573651e937cb82111718ce0.png

    撰文|梁德澎

    本文首发于公众号GiantPandaCV

    本文主要推导 InstanceNorm 关于输入和参数的梯度公式,同时还会结合 PyTorch 和 MXNet 里的 InstanceNorm 代码来分析。

    1

    InstanceNorm 与 BatchNorm 的联系

    对一个形状为 (N, C, H, W) 的张量应用 InstanceNorm[4] 操作,其实等价于先把该张量 reshape 为 (1, N * C, H, W)的张量,然后应用 BatchNorm[5] 操作。而 gamma 和 beta 参数的每个通道所对应输入张量的位置都是一致的。

    而 InstanceNorm 与 BatchNorm 不同的地方在于:

    • InstanceNorm 训练与预测阶段行为一致,都是利用当前 batch 的均值和方差计算

    • BatchNorm 训练阶段利用当前 batch 的均值和方差,测试阶段则利用训练阶段通过移动平均统计的均值和方差

    论文[6]中的一张示意图,就很好地解释了两者的联系:

    d9d0599076cfccca2b66494d678c547e.png

    https://arxiv.org/pdf/1803.08494.pdf
    所以 InstanceNorm 对于输入梯度和参数求导过程与 BatchNorm 类似,下面开始进入正题。

    2

    梯度推导过程详解

    在开始推导梯度公式之前,首先约定输入,参数,输出等符号:

    • 输入张量 , 形状为(N, C, H, W),rehape 为 (1, N * C, M) 其中 M=H*W

    • 参数 ,形状为 (1, C, 1, 1),每个通道值对应 N*M 个输入,在计算的时候首先通过在第0维 repeat N 次再 reshape 成 (1, N*C, 1, 1)

    • 参数 ,形状为 (1, C, 1, 1),每个通道值对应 N*M 个输入,在计算的时候首先通过在第0维 repeat N 次再 reshape 成 (1, N*C, 1, 1)

    而输入张量 reshape 成 (1, N * C, M)之后,每个通道上是一个长度为 M 的向量,这些向量之间的计算是不像干的,每个向量计算自己的 normalize 结果。所以求导也是各自独立。因此下面的均值、方差符号约定和求导也只关注于其中一个向量,其他通道上的向量计算都是一样的。

    • 一个向量上的均值 

    • 一个向量上的方差 

    • 一个向量上一个点的 normalize 中间输出 

    • 一个向量上一个点的 normalize 最终输出 ,其中  和  表示这个向量所对应的 gamma 和 beta 参数的通道值。

    • loss 函数的符号约定为 

    gamma 和 beta 参数梯度的推导

    先计算简单的部分,求 loss 对  和  的偏导:

    95aaef6d2174952a72a4b5e4b829103a.png

    其中  表示 gamma 和 beta 参数的第  个通道参与了哪些 batch 上向量的 normalize 计算。

    因为 gamma 和 beta 上的每个通道的参数都参与了 N 个 batch 上 M 个元素 normalize 的计算,所以对每个通道进行求导的时候,需要把所有涉及到的位置的梯度都累加在一起。

    对于  在具体实现的时候,就是对应输出梯度的值,也就是从上一层回传回来的梯度值。

    输入梯度的推导

    对输入梯度的求导是最复杂的,下面的推导都是求 loss 相对于输入张量上的一个点上的梯度,而因为上文已知,每个长度是 M 的向量的计算都是独立的,所以下文也是描述其中一个向量上一个点的梯度公式。具体是计算的时候,是通过向量操作(比如 numpy)来完成所有点的梯度计算。

    先看 loss 函数对于  的求导:

    3dfcb3547495b2c2d74f8ffcc1221a79.png

    而从上文约定的公式可知,对于 

    402 Payment Required

     的计算中涉及到  的有三部分,分别是 、 和 。所以 loss 对于 的偏导可以写成以下的形式:

    ffeb06e4c1591a7fd67c1443be43ec7d.png

    接下来就是,分别求上面式子最后三项的梯度公式。

    第一项梯度推导

    在求第一项的时候,把  和  看做常量,则有:

  • 相关阅读:
    基础ArkTS组件:二维码,滚动条与滑动条,多选框与多选框群组(HarmonyOS学习第三课【3.4】)
    mysql—表单二
    解决flume采集日志使用KafkaChannel写不到hdfs的问题
    VPS和云服务器的区别
    2022国赛数模A题思路以及解析(附源码 可供学习训练使用)
    设计模式之备忘录模式
    [SpringBoot系列]NoSQL数据层解决方案
    【23种设计模式】单一职责原则
    Java学习
    RT-DETR算法改进:最新Inner-IoU损失函数,辅助边界框回归的IoU损失,提升RT-DETR检测器精度
  • 原文地址:https://blog.csdn.net/OneFlow_Official/article/details/123288435