• 一文弄懂CNN中的BatchNorm


    1. 引言

    本文重点介绍BatchNorm的定义和相关特性,并介绍了其详细实现和具体应用。希望可以帮助大家加深对其理解。

    嗯嗯,闲话少说,我们直接开始吧!

    2. 什么是BatchNorm

    BatchNorm是2015年提出的网络层,这个层具有以下特性:

    • 易于训练:由于网络权重的分布随这一层的变化小得多,因此我们可以使用更高的学习率。我们在训练中收敛的方向没有那么不稳定,这样我们就可以更快地朝着loss收敛的方向前进。

    • 提升正则化:尽管网络在每个epoch都会遇到相同的训练样本,但每个小批量的归一化是不同的,因此每次都会稍微改变其值。

    • 提升精度:可能是由于前面两点的结合,论文提到他们获得了比当时最先进的结果更好的准确性。

    3. BatchNorm是如何工作的?

    BatchNorm所做的是确保接收到的输入具有平均值0和标准偏差1。
    本文中介绍的算法如下:
    在这里插入图片描述
    下面是我自己用pytorch进行的实现:

    import numpy as np
    import torch
    from torch import nn
    from torch.nn import Parameter
    
    class BatchNorm(nn.Module):
        def __init__(self, num_features, eps=1e-5, momentum=0.1):
            super().__init__()
            self.gamma = Parameter(torch.Tensor(num_features))
            self.beta = Parameter(torch.Tensor(num_features))
            self.register_buffer("moving_avg", torch.zeros(num_features))
            self.register_buffer("moving_var", torch.ones(num_features))
            self.register_buffer("eps", torch.tensor(eps))
            self.register_buffer("momentum", torch.tensor(momentum))
            self._reset()
        
        def _reset(self):
            self.gamma.data.fill_(1)
            self.beta.data.fill_(0)
        
        def forward(self, x):
            if self.training:
                mean = x.mean(dim=0)
                var = x.var(dim=0)
                self.moving_avg = self.moving_avg * momentum + mean * (1 - momentum)
                self.moving_var = self.moving_var * momentum + var * (1 - momentum)
            else:
                mean = self.moving_avg
                var = self.moving_var
                
            x_norm = (x - mean) / (torch.sqrt(var + self.eps))
            return x_norm * self.gamma + self.beta
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32

    这里对其进行补充说明如下:

    • 我们在训练和推理过程中BatchNorm有不同的行为。在训练中,我们记录均值和方差的指数移动平均值,以供以后在推理时使用。其原因是,在训练期间处理批次时,我们可以获得输入随时间变化的均值和方差的更好估计,然后将其用于推理。在推理过程中使用输入批次的平均值和方差将不太准确,因为其大小可能比训练中使用的小得多,大数定律在这里发挥了作用。

    4. 什么时候使用Batchnorm ?

    这似乎总是有帮助的,所以没有理由不使用它。通常它出现在全连接层/卷积层和激活函数之间。但也有人认为,最好把它放在激活层之后。我找不到任何关于激活函数之后使用它的论文,所以最安全的选择是按照每个人的做法,在激活函数前使用它。

    5. 一些技巧总结

    列举下关于实际应用中BatchNorm的技巧总结如下:

    • 我们知道,一个已经训练的网络包含用于训练它的数据集的移动平均值和方差,这可能是一个问题。在迁移学习期间,我们通常会冻结大部分层,如果不小心,BatchNorm层也会冻结,这意味着应用的移动平均值属于原始数据集,而不是新数据集。解冻BatchNorm层是一个好主意,将允许网络重新计算自己数据集上的移动平均值和方差。
  • 相关阅读:
    Streaming Systems
    DStream转换介绍_大数据培训
    亚马逊、Shopee、美客多店铺出单量如何提高?有何方法?
    Mysql内置函数
    外汇天眼:乐天证券扩大了交易工具!进入数字资产市场!
    C++类与对象(3)—拷贝构造函数&运算符重载
    FastAPI获年度第一新兴框架,2021年最受欢迎的TOP 100开发工具出炉
    SpringBoot如何缓存方法返回值?
    R语言比较两个样本的均值是否相同:使用t.test函数执行t检验通过比较来自两个总体的样本数据判断总体均值是否相同
    《痞子衡嵌入式半月刊》 第 54 期
  • 原文地址:https://blog.csdn.net/sgzqc/article/details/127952552