• Pytorch中关于forward函数的理解与用法


    前言

    深入深度学习框架的代码,发现forward函数没有被显示调用

    但代码确重写了forward函数,于是好奇是不是python的魔术方法作用

    1. 问题所示

    代码如下所示:

    class Module(nn.Module):
    	
    	# 初始化
        def __init__(self):
            super(Module, self).__init__()
            # ......
        # 前向传播
        def forward(self, x):
            # ......
            return x
            
    # 输入数据
    data = .....  
    
    # 实例化
    module = Module()
    
    # 前向传播
    module(data)  
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    整个代码串没有显示调用forward函数
    由此引发疑问:

    1. 谁去调用forward函数?
    2. 什么时候调用forward函数?

    2. 原理分析

    回顾python的基础知识:python 类和对象的详细分析
    可以清楚知道对象需要执行方法,在方法中传入参数即可,类似 module.forward(data),但是执行对象(参数)就可成功。

    这也说明:module(data) 等价于 module.forward(data)
    即该代码块调用了forward函数(那他是怎样实现什么时候调用的呢)

    本身Pytorch大部分操作都是通过继承nn.Module类实现,查看其源代码:

    class Module(object):
        def __init__(self):
        def forward(self, *input):
     
        def add_module(self, name, module):
        def cuda(self, device=None):
        def cpu(self):
        def __call__(self, *input, **kwargs):
        def parameters(self, recurse=True):
        def named_parameters(self, prefix='', recurse=True):
        def children(self):
        def named_children(self):
        def modules(self):  
        def named_modules(self, memo=None, prefix=''):
        def train(self, mode=True):
        def eval(self):
        def zero_grad(self):
        def __repr__(self):
        def __dir__(self):
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    内部中有个def __call__(self, *input, **kwargs):函数,默认父类会执行该函数

    大致如下:

    class Module():
        def __call__(self, data):        
            print(data)
            
    module = Module()
    
    # 输出 1
    module(1)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    这正说明,深度学习的模型继承了nn.Module类,内部的__call__方法有对forward方法的调用,才不用显式地调用forward方法。
    对此,深度学习的模型框架需要重写构造函数中的__init__函数和forward函数。

    2.1 forward函数理解

    1. 通过module中的__call__方法
    2. __call__方法调用module中的forward方法
    3. forward方法
      —若碰到Module子类,则迭代回馈第一步;
      —若碰到Function子类,则执行第四步;
    4. 调用Function子类中的call方法
    5. __call__方法调用Function中的forward方法
    6. 由于层层嵌套,现在只需回馈上一层的值即可
      ( Function中的forward返回值 ->
      module中的forward返回值 ->
      module中的__call__进行forward_hook返回值)

    代码逻辑如下:

    def __call__(self, *input, **kwargs):
    	# 此处执行forward函数
    	result = self.forward(*input, **kwargs)
    	for hook in self._forward_hooks.values():
    		#将注册的hook拿出来用
    		hook_result = hook(self, input, result)
    
    return result
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 围观角度:所谓的__call__为函数调用,只需要将该类型的对象当做函数使用即可,即 module(data) 等价于 module.forward(data)

    • 宏观角度:当一个类默认实现特殊方法__call__,该类的实例就变成可调用的类型,即对象名() 等价于 对象名.__call__()

    2.2 forward函数用法

    CNN可学习的参数层和不可学习的参数层,大致如下:

    • 可学习的参数:卷积层和全连接层的权重、bias、BatchNorm的β和γ等。
    • 不可学习的参数(超参数):学习率、batch size、weight decay、模型的深度宽度分辨率等。
    • Module类中的init构造函数一般放置可学习的参数,其不可学习的参数如果不放置在init层,则在forward函数中可用nn.functional来代替。
    • forward函数必须重写(实现模型功能,链接各层之间的功能)
  • 相关阅读:
    SpringBoot注册web组件
    Java8 判空这样写,惊艳,又骚气
    Redis数据库角色:不只是缓存,还可以作为主数据库!
    【长文档】进行排版的正确顺序?
    JAVA中小型医院信息管理系统源码 医院系统源码
    MySQL使用简单教程
    FinClip 支持创建 H5应用类小程序;PC 终端 优化升级
    建模杂谈系列156 一个接口服务的改版升级
    为什么六位数高薪仍无法让技术人员感到满足?
    C++实现telnet动态调试模块:将日志输出到telnet终端,通过telnet终端调用自定义注册的函数
  • 原文地址:https://blog.csdn.net/weixin_47872288/article/details/133364787