• PyTorch函数中的__call__和forward函数


    初学nn.Module,看不懂各种调用,后来看明白了,估计会忘,故写篇笔记记录

    init & call

    代码:

    class A():
        def __init__(self):
            print('init函数')        
            
        def __call__(self, param):
            print('call 函数', param)
    a = A()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    输出
    在这里插入图片描述
    分析:A进行类的实例化,生成对象a,这个过程自动调用_init_(),没有调用_call_()


    上面的代码加一行

    class A():
        def __init__(self):
            print('init函数') 
            
        def __call__(self, param):
            print('call 函数', param)
    a = A()
    a(1)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    输出
    在这里插入图片描述
    分析:a是对象,python中让对象有了像函数一样加括号(参数)的功能,使用这种功能时,自动调用_call_()


    _ call_()中可以调用其它函数,如forward函数

    class A():
        def __init__(self):
            print('init函数')
            
        def __call__(self, param):
            print('call 函数', param)
            res = self.forward(param)
            return res + 2
            
        def forward(self, input_): 
            print('forward 函数', input_)
            return input_
        
    a = A()
    b = a(1)
    print('结果b =',b)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    在这里插入图片描述
    分析:_call _()成功调用了forward(),且返回值给了b


    另外我之前有个误解,以为该类的值只有参数声明了才能用,这是错误的

    class A():
        def __init__(self):
            print('init函数')
            self.a = 100  # 声明参数a
            
        def __call__(self, param):
            print('call 函数', param)
            res = self.forward(param)
            return res + 2
            
        def forward(self, input_): 
            print('forward 函数', input_, self.a)
            return input_
        
    a = A()
    b = a(1)
    print('结果b =',b)
    print(a.a)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    在这里插入图片描述


    nn.Module

    看了上面的例子,就知道了_call _()的作用,那下面看更CNN的例子

    from torch import nn
    import torch
    
    class Ding(nn.Module):
        def __init__(self):
            print('init')
            super().__init__()
        
        def forward(self, input):
            output = input + 1
            print("forward")
            return output
    
    dzy = Ding()
    x = torch.tensor(1.0)
    out = dzy(x)
    print(out)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    结果:
    在这里插入图片描述
    分析:
    这里并没有调用_call_() 和forward(),但还是显示了forward,原因是:Ding这个子类继承了父类nn.Module里的call函数,接下来去源码看
    在这里插入图片描述
    发现_call_调用了_call_impl这个函数,相当于起了个外号一样,那就去这个函数看

    在这里插入图片描述
    在这里插入图片描述

    这里有很多参数,详细可见参考2。发现这里forward_call 要么是_slow_forward,要么是self.forward(),而这个_slow_forward()也会用self.forward()
    在这里插入图片描述
    所以: _call _()用了forward,而这个父类的forward在子类中重写了(简单代码)
    在这里插入图片描述


    当然,也可以重写__call__(),比如我们不让它使用forward()

    from torch import nn
    import torch
    
    class Ding(nn.Module):
        def __init__(self):
            print('init')
            super().__init__()
            
        def __call__(self, input_):
            print('重写call, 不用forward')
            return 'hhh'
            
        def forward(self, input):
            output = input + 1
            print("forward")
            return output
    
    dzy = Ding()
    x = torch.tensor(1.0)
    out = dzy(x)
    print(out)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    在这里插入图片描述

    总结

    使用对象dzy(x)时,用了父类nn.Module的call函数,调用了forward,而这个forward又被我们在子类里重写了。

    参考

    https://blog.csdn.net/dss_dssssd/article/details/83750838
    https://zhuanlan.zhihu.com/p/366461413

  • 相关阅读:
    网络安全(黑客)自学
    Perl 语言入门教程
    【VUE】ElementPlus之动态主题色调切换(Vue3 + Element Plus+Scss + Pinia)
    从.net开发做到云原生运维(零)——序
    深度学习常见损失函数总结+Pytroch实现
    SVC服务的发布
    CSS 圆角渐变边框
    获取当周和上周的周一、周日时间
    PX4模块设计之十一:Built-In框架
    Java web程序实现请求后重启服务动作
  • 原文地址:https://blog.csdn.net/qq_43745026/article/details/125537774