撰文|梁德澎
本文首发于公众号GiantPandaCV
本文主要推导 InstanceNorm 关于输入和参数的梯度公式,同时还会结合 PyTorch 和 MXNet 里的 InstanceNorm 代码来分析。
对一个形状为 (N, C, H, W) 的张量应用 InstanceNorm[4] 操作,其实等价于先把该张量 reshape 为 (1, N * C, H, W)的张量,然后应用 BatchNorm[5] 操作。而 gamma 和 beta 参数的每个通道所对应输入张量的位置都是一致的。
而 InstanceNorm 与 BatchNorm 不同的地方在于:
InstanceNorm 训练与预测阶段行为一致,都是利用当前 batch 的均值和方差计算
BatchNorm 训练阶段利用当前 batch 的均值和方差,测试阶段则利用训练阶段通过移动平均统计的均值和方差
论文[6]中的一张示意图,就很好地解释了两者的联系:
在开始推导梯度公式之前,首先约定输入,参数,输出等符号:
输入张量 , 形状为(N, C, H, W),rehape 为 (1, N * C, M) 其中 M=H*W
参数 ,形状为 (1, C, 1, 1),每个通道值对应 N*M 个输入,在计算的时候首先通过在第0维 repeat N 次再 reshape 成 (1, N*C, 1, 1)
参数 ,形状为 (1, C, 1, 1),每个通道值对应 N*M 个输入,在计算的时候首先通过在第0维 repeat N 次再 reshape 成 (1, N*C, 1, 1)
而输入张量 reshape 成 (1, N * C, M)之后,每个通道上是一个长度为 M 的向量,这些向量之间的计算是不像干的,每个向量计算自己的 normalize 结果。所以求导也是各自独立。因此下面的均值、方差符号约定和求导也只关注于其中一个向量,其他通道上的向量计算都是一样的。
一个向量上的均值
一个向量上的方差
一个向量上一个点的 normalize 中间输出
一个向量上一个点的 normalize 最终输出 ,其中 和 表示这个向量所对应的 gamma 和 beta 参数的通道值。
loss 函数的符号约定为
gamma 和 beta 参数梯度的推导
先计算简单的部分,求 loss 对 和 的偏导:
其中 表示 gamma 和 beta 参数的第 个通道参与了哪些 batch 上向量的 normalize 计算。
因为 gamma 和 beta 上的每个通道的参数都参与了 N 个 batch 上 M 个元素 normalize 的计算,所以对每个通道进行求导的时候,需要把所有涉及到的位置的梯度都累加在一起。
对于 在具体实现的时候,就是对应输出梯度的值,也就是从上一层回传回来的梯度值。
对输入梯度的求导是最复杂的,下面的推导都是求 loss 相对于输入张量上的一个点上的梯度,而因为上文已知,每个长度是 M 的向量的计算都是独立的,所以下文也是描述其中一个向量上一个点的梯度公式。具体是计算的时候,是通过向量操作(比如 numpy)来完成所有点的梯度计算。
先看 loss 函数对于 的求导:
而从上文约定的公式可知,对于
的计算中涉及到 的有三部分,分别是 、 和 。所以 loss 对于 的偏导可以写成以下的形式:
接下来就是,分别求上面式子最后三项的梯度公式。
在求第一项的时候,把 和 看做常量,则有: