• BN体系理解——类封装复现


     

     

     

     

     

    1. from pathlib import Path
    2. from typing import Optional
    3. import torch
    4. import torch.nn as nn
    5. from torch import Tensor
    6. class BN(nn.Module):
    7. def __init__(self,num_features,momentum=0.1,eps=1e-8):##num_features是通道数
    8. """
    9. 初始化方法
    10. :param num_features:特征属性的数量,也就是通道数目C
    11. """
    12. super(BN, self).__init__()
    13. ##register_buffer:将属性当成parameter进行处理,唯一的区别就是不参与反向传播的梯度求解
    14. self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))
    15. self.register_buffer('running_var', torch.zeros(1, num_features, 1, 1))
    16. self.running_mean: Optional[Tensor]
    17. self.running_var: Optional[Tensor]
    18. self.running_mean=torch.zeros([1,num_features,1,1])
    19. self.running_var=torch.zeros([1,num_features,1,1])
    20. self.gamma=nn.Parameter(torch.ones([1,num_features,1,1]))
    21. self.beta=nn.Parameter(torch.zeros(1,num_features,1,1))
    22. self.eps=eps
    23. self.momentum=momentum
    24. def forward(self,x):
    25. """
    26. 前向过程
    27. output=(x-μ)/α*γ+β
    28. :param x: [N,C,H,W]
    29. :return: [N,C,H,W]
    30. """
    31. if self.training:
    32. #训练阶段--》使用当前批次的数据
    33. _mean=torch.mean(x,dim=(0,2,3),keepdim=True)
    34. _var = torch.var(x, dim=(0, 2, 3), keepdim=True)
    35. #将训练过程中的均值和方差保存下来--方便推理的时候使用--》滑动平均
    36. self.running_mean=self.momentum*self.running_mean+(1.0-self.momentum)*_mean
    37. self.running_var=self.momentum*self.running_var+(1.0-self.momentum)*_var
    38. else:
    39. #推理阶段-->使用的是训练过程中的累积数据
    40. _mean=self.running_mean
    41. _var=self.running_var
    42. z=(x-_mean)/torch.sqrt(_var+self.eps)*self.gamma+self.beta
    43. return z
    44. if __name__ == '__main__':
    45. torch.manual_seed(28)
    46. path_dir=Path("./output/models")
    47. path_dir.mkdir(parents=True,exist_ok=True)
    48. device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    49. bn=BN(num_features=12)
    50. bn.to(device)#只针对子模块和参数进行转换
    51. #模拟训练过程
    52. bn.train()
    53. xs=[torch.randn(8,12,32,32).to(device) for _ in range(10)]
    54. for _x in xs:
    55. bn(_x)
    56. print(bn.running_mean.view(-1))
    57. print(bn.running_var.view(-1))
    58. #模拟推理过程
    59. bn.eval()
    60. _r=bn(xs[0])
    61. print(_r.shape)
    62. bn=bn.cpu()#保存都是以cpu保存,恢复再自己转回GPU上
    63. #模拟模型保存
    64. torch.save(bn,str(path_dir/'bn_model.pkl'))
    65. #state_dict:获取当前模块的所有参数(Parameter+register_buffer)
    66. torch.save(bn.state_dict(),str(path_dir/"bn_params.pkl"))
    67. #pt结构的保存
    68. traced_script_module=torch.jit.trace(bn.eval(),xs[0].cpu())
    69. traced_script_module.save("./output/bn_model.pt")
    70. #模拟模型恢复
    71. bn_model=torch.load(str(path_dir/"bn_model.pkl"),map_location='cpu')
    72. bn_params=torch.load(str(path_dir/"bn_params.pkl"),map_location='cpu')
    73. print(len(bn_params))

  • 相关阅读:
    Linux 命令 —— feh
    货币银行学简答论述题
    上传项目的全部依赖到maven私有仓库-nexus
    EPOLL(C/S模型)实现I/O复用多进程聊天室,通过共享内存、socketpair实现父子进程通信,通过信号量回收进程
    高速串行总线——SATA
    带头双向循环链表的实现(C语言)
    搭建Android自动化python+appium环境
    【深度思考】聊聊CGLIB动态代理原理
    在浏览器中输入url回车后发生了什么
    CSDN课程推荐:《【专题】SecureBoot精讲》系列课程上线
  • 原文地址:https://blog.csdn.net/weixin_42601270/article/details/133771187