• pytorch的buffer学习整理


    pytorch模型中的buffer

    这段时间忙于做项目,但是在项目中一直在模型构建中遇到buffer数据,所以花点时间整理下模型中的parameter和buffer数据的区别💕

    1.torch.nn.Module.named_buffers(prefix=‘‘, recurse=True)

    贴上pytorch官网对其的说明:
    在这里插入图片描述
    官网翻译:

    named_buffers(prefix='', recurse=True)
    方法: named_buffers(prefix='', recurse=True)
    
        Returns an iterator over module buffers, yielding both the name of the buffer as well 
        as the buffer itself.
        返回一个迭代器,该迭代器能够遍历模块的缓冲buffer,并且迭代返回的结果是缓冲的名字和缓冲本身.
        Parameters  参数
                prefix (str) – prefix to prepend to all buffer names.
                prefix (字符串) – 添加到所有缓冲名字之前的前缀.
                recurse (bool)if True, then yields buffers of this module and all submodules. 
                Otherwise, yields only buffers that are direct members of this module.
                recurse (布尔类型) – 如果该参数是True,那么表示递归地迭代返回,即迭代返回该模块的缓冲以及
                该模块的所有子模块的缓冲. 默认为True
        Yields  迭代返回
            (string, torch.Tensor) – Tuple containing the name and buffer
            (字符串,torch.Tensor类型) - 包含缓冲名字和缓冲自身的元组
            
        Example:  例子:
    
        >>> for name, buf in self.named_buffers():
        >>>    if name in ['running_var']:
        >>>        print(buf.size())
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    总结,缓冲buffer必须要登记注册才会有效,如果仅仅将张量赋值给Module模块的属性,不会被自动转为缓冲buffer.因而也无法被state_dict()、buffers()、named_buffers()访问到.此外state_dict()可以遍历缓冲buffer和参数Parameter.
    可以概括为,缓冲buffer和参数Parameter的区别是前者不需要训练优化,而后者需要训练优化.在创建方法上也有区别,前者必须要将一个张量使用方法register_buffer()来登记注册,后者比较灵活,可以直接赋值给模块的属性,也可以使用方法register_parameter()来登记注册.
    下面使用代码测试一下buffer数据:

    import torch 
    import torch.nn as nn
    torch.manual_seed(seed=20200910)
    class Model(torch.nn.Module):
        def __init__(self):
            super(Model,self).__init__()
            self.conv1=torch.nn.Sequential(  # 输入torch.Size([64, 1, 28, 28])
                    torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1),
                    torch.nn.ReLU(),  # 输出torch.Size([64, 64, 28, 28])
            )
            self.attribute_buffer_in = torch.randn(3,5)                       # 仅仅赋值给模型属性,是无法访问到该buffer数据
            register_buffer_in_temp = torch.randn(4,6)               
            self.register_buffer('register_buffer_in', register_buffer_in_temp)   # 注册buffer数据,才能生效,能获取到数据
    
        def forward(self,x): 
            pass
    
    print('cuda(GPU)是否可用:',torch.cuda.is_available())
    print('torch的版本:',torch.__version__)
    model = Model() #.cuda()
    
    
    
    print('初始化之后模型修改之前'.center(100,"-"))
    print('调用named_buffers()'.center(100,"-"))   
    for name, buf in model.named_buffers():
        print(name,'-->',buf.shape)
    
    print('调用named_parameters()'.center(100,"-"))
    for name, param in model.named_parameters():     # 访问模型的parameter参数数据的名字和其本身
        print(name,'-->',param.shape)
    
    print('调用buffers()'.center(100,"-"))           # 访问模型中的buffer数据本身
    for buf in model.buffers():
        print(buf.shape)
    
    print('调用parameters()'.center(100,"-"))        # 访问模型中的parameter数据本身
    for param in model.parameters():
        print(param.shape)
    
    print('调用state_dict()'.center(100,"-"))        # 同时获取模型的parameter参数数据、buffer参数数据
    for k, v in model.state_dict().items():
        print(k, '-->', v.shape)
    
    
    
    model.attribute_buffer_out = torch.randn(10,10)      # 赋值给模型属性
    register_buffer_out_temp = torch.randn(15,15)
    model.register_buffer('register_buffer_out', register_buffer_out_temp)  # 通过注册的方式,使得模型的buffer成员属性生效
    print('模型初始化以及修改之后'.center(100,"-"))
    print('调用named_buffers()'.center(100,"-"))         # 修改模型buffer属性之后,访问buffer数据名字和其本身
    for name, buf in model.named_buffers():
        print(name,'-->',buf.shape)
    
    print('调用named_parameters()'.center(100,"-"))      # 修改模型buffer属性之后,访问模型parameter数据名字和其本身
    for name, param in model.named_parameters():
        print(name,'-->',param.shape)
    
    print('调用buffers()'.center(100,"-"))
    for buf in model.buffers():
        print(buf.shape)
    
    print('调用parameters()'.center(100,"-"))
    for param in model.parameters():
        print(param.shape)
    
    print('调用state_dict()'.center(100,"-"))
    for k, v in model.state_dict().items():
        print(k, '-->', v.shape)  
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69

    输出结果为:

    Windows PowerShell
    版权所有 (C) Microsoft Corporation。保留所有权利。
    
    尝试新的跨平台 PowerShell https://aka.ms/pscore6
    
    加载个人及系统配置文件用了 840 毫秒。
    (base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch1_2_0
    (ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>  & 'D:\Anaconda3\envs\ssd4pytorch1_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2020.12.424452561\pythonFiles\lib\python\debugpy\launcher' '63490' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\test2.py'
    cuda(GPU)是否可用: True
    torch的版本: 1.2.0+cu92
    --------------------------------------------初始化之后模型修改之前---------------------------------------------
    -----------------------------------------调用named_buffers()------------------------------------------
    register_buffer_in --> torch.Size([4, 6])                     # 
    ----------------------------------------调用named_parameters()----------------------------------------
    conv1.0.weight --> torch.Size([64, 1, 3, 3])
    conv1.0.bias --> torch.Size([64])
    --------------------------------------------调用buffers()---------------------------------------------
    torch.Size([4, 6])
    -------------------------------------------调用parameters()-------------------------------------------
    torch.Size([64, 1, 3, 3])
    torch.Size([64])
    -------------------------------------------调用state_dict()-------------------------------------------
    register_buffer_in --> torch.Size([4, 6])
    conv1.0.weight --> torch.Size([64, 1, 3, 3])
    conv1.0.bias --> torch.Size([64])
    --------------------------------------------模型初始化以及修改之后---------------------------------------------
    -----------------------------------------调用named_buffers()------------------------------------------
    register_buffer_in --> torch.Size([4, 6])
    register_buffer_out --> torch.Size([15, 15])
    ----------------------------------------调用named_parameters()----------------------------------------
    conv1.0.weight --> torch.Size([64, 1, 3, 3])
    conv1.0.bias --> torch.Size([64])
    --------------------------------------------调用buffers()---------------------------------------------
    torch.Size([4, 6])
    torch.Size([15, 15])
    -------------------------------------------调用parameters()-------------------------------------------
    torch.Size([64, 1, 3, 3])
    torch.Size([64])
    -------------------------------------------调用state_dict()-------------------------------------------
    register_buffer_in --> torch.Size([4, 6])
    register_buffer_out --> torch.Size([15, 15])
    conv1.0.weight --> torch.Size([64, 1, 3, 3])
    conv1.0.bias --> torch.Size([64])
    (ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44

    模型中的buffer和parameter区别

    在这里插入图片描述
    在这里插入图片描述
    下面使用代码进行说明:
    pytorch保存模型参数的一种方式为:

    # save
    torch.save(model.state_dict(), PATH)
    
    # load
    model = MyModel(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    可以看到模型保存的是 model.state_dict() 的返回对象。 model.state_dict() 的返回对象是一个 OrderDict ,它以键值对的形式包含模型中需要保存下来的参数,例如:

    class MyModule(nn.Module):
        def __init__(self, input_size, output_size):
            super(MyModule, self).__init__()
            self.lin = nn.Linear(input_size, output_size)
        def forward(self, x):
            return self.lin(x)
    
    module = MyModule(4, 2)
    print(module.state_dict())
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    输出结果:
    在这里插入图片描述
    分析一个parameter和buffer的例子:

    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            buffer = torch.randn(2, 3)  # tensor
            self.register_buffer('my_buffer', buffer)
            self.param = nn.Parameter(torch.randn(3, 3))  # 模型的成员变量
    
        def forward(self, x):
            # 可以通过 self.param 和 self.my_buffer 访问
            pass
    model = MyModel()
    for param in model.parameters():
        print(param)
    print("----------------")
    for buffer in model.buffers():
        print(buffer)
    print("----------------")
    print(model.state_dict())
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    输出结果:
    在这里插入图片描述

    在这里插入图片描述

  • 相关阅读:
    常用git命令
    java版Spring Cloud之Spark 离线开发框架设计与实现
    基于条件谱矩的时间序列分析(以轴承故障诊断为例,MATLAB)
    第二证券:产业资本真金白银传递市场信心
    蓝桥杯前端Web赛道-输入搜索联想
    socket编程常用API
    python 经典案例(3)
    我的前端开发技巧
    Rust如何开发eBPF应用?(一)
    kafka详解(二)--kafka为什么快
  • 原文地址:https://blog.csdn.net/qq_38765642/article/details/128013565