• pytorch的四个hook函数


      训练神经网络模型有时需要观察模型内部模块的输入输出,或是期望在不修改原始模块结构的情况下调整中间模块的输出,pytorch可以用hook回调函数来实现这一功能。主要使用四个hook注册函数:register_forward_hook、register_forward_pre_hook、register_full_backward_hook、register_full_backward_pre_hook。这四个函数可以被继承nn.Module的任意模块调用,传入hook函数并进行注册,从而在执行该模块的相应阶段调用hook函数实现所需功能。

    register_forward_hook(self, hook, *, prepend, with_kwargs)

      为模块注册一个在该模块前向传播之后执行的回调函数。

      hook(module, args, output):需执行的回调函数对象,module为当前模块引用,args为当前模块前向传播输入,output为当前模块前向传播输出。可以返回修改后的output来修改该模块前向传播输出。

      prepend:将该hook函数放在回调函数列表最前面,从而最先执行,否则放在队列最后。

      with_kwargs:hook函数是否传入关键字参数,如果为True,则hook可以额外增加关键则参数。

      register_forward_hook注册函数本身返回一个handle句柄,可执行handle.remove()将注册的该hook函数移除。

    register_forward_pre_hook(self, hook, *, prepend, with_kwargs)

      为模块注册一个在该模块前向传播之前执行的回调函数。

      hook(module, args):args为该模块前向传播输入。可以返回修改后的args来修改该模块前向传播输入。

      其它参数、特性与前面一致。

    register_full_backward_hook(self, hook, prepend)

      为模块注册一个在该模块反向传播之后执行的回调函数。

      hook(module, grad_input, grad_output):grad_input与grad_output分别为该模块前向传播输入和输出的梯度。可以返回修改后的grad_input来修改该模块前向传播输入的梯度。

    register_full_backward_pre_hook(self, hook, prepend)

      为模块注册一个在该模块反向传播之前执行的回调函数。

      hook(module, grad_output):grad_output为该模块前向传播输出的梯度。可以返回修改后的grad_output来修改这一梯度。

  • 相关阅读:
    Springboot下RedisTemplate的两种序列化方式
    3环境变量
    从零搭建开发脚手架 顺应潮流开启升级 - SpringBoot 从2.x 升级到3.x
    图解HTTP笔记整理(前六章)
    Java设计模式之观察者模式(Observer Pattern)
    Redis HyperLogLog的使用
    论文笔记(二十二):GRiD: GPU-Accelerated Rigid Body Dynamics with Analytical Gradients
    LabVIEW程序代码更新缓慢
    程序员最爱用的在线代码编辑器合集,哪款是你的最爱?
    ASP.NET SignalR
  • 原文地址:https://blog.csdn.net/qq_37189298/article/details/133698900