• Deep Learning Basics: Parameter Count (3)


    Parameter count estimation code for a typical CNN network

    import torch
    import torch.nn as nn
    from torchsummary import summary

    class ResidualBlock(nn.Module):
        def __init__(self, in_planes, planes, norm_fn='group', stride=1):
            super(ResidualBlock, self).__init__()
            # Two 3x3 convolutions; the first one carries the stride.
            self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
            self.relu = nn.ReLU(inplace=True)

            # Choose the normalization layers according to norm_fn.
            num_groups = planes // 8
            if norm_fn == 'group':
                self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
                self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
                if not stride == 1:
                    self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            elif norm_fn == 'batch':
                self.norm1 = nn.BatchNorm2d(planes)
                self.norm2 = nn.BatchNorm2d(planes)
                if not stride == 1:
                    self.norm3 = nn.BatchNorm2d(planes)
            elif norm_fn == 'instance':
                self.norm1 = nn.InstanceNorm2d(planes)
                self.norm2 = nn.InstanceNorm2d(planes)
                if not stride == 1:
                    self.norm3 = nn.InstanceNorm2d(planes)
            elif norm_fn == 'none':
                self.norm1 = nn.Sequential()
                self.norm2 = nn.Sequential()
                if not stride == 1:
                    self.norm3 = nn.Sequential()

            # A 1x1 convolution matches the shortcut to the output shape when stride != 1.
            if stride == 1:
                self.downsample = None
            else:
                self.downsample = nn.Sequential(
                    nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

        def forward(self, x):
            y = x
            y = self.relu(self.norm1(self.conv1(y)))
            y = self.relu(self.norm2(self.conv2(y)))
            if self.downsample is not None:
                x = self.downsample(x)
            return self.relu(x + y)

    R = ResidualBlock(384, 384, norm_fn='instance', stride=1)
    summary(R.to("cuda" if torch.cuda.is_available() else "cpu"), (384, 32, 32))
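
    The output of summary can be cross-checked by hand: a 3×3 convolution contributes C_in × C_out × 3 × 3 weights plus C_out biases, while nn.InstanceNorm2d (affine=False by default) adds no parameters, so with stride=1 the block above only counts its two convolutions. A minimal sanity check, reusing the R instance created above:

    # Hand count for ResidualBlock(384, 384, norm_fn='instance', stride=1):
    # two 3x3 convs, no affine norm parameters, no downsample branch.
    conv1_params = 384 * 384 * 3 * 3 + 384   # weights + biases of conv1
    conv2_params = 384 * 384 * 3 * 3 + 384   # weights + biases of conv2
    hand_count = conv1_params + conv2_params
    torch_count = sum(p.numel() for p in R.parameters() if p.requires_grad)
    print(hand_count, torch_count)  # both should print 2654976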

    Estimating the parameter count of a Transformer model

    import torch
    import torch.nn as nn
    from thop import profile

    # Define a simple Transformer model
    class Transformer(nn.Module):
        def __init__(self, input_dim, hidden_dim, num_heads, num_layers):
            super(Transformer, self).__init__()
            self.embedding = nn.Embedding(input_dim, hidden_dim)
            self.transformer_layers = nn.Transformer(
                d_model=hidden_dim,
                nhead=num_heads,
                num_encoder_layers=num_layers,
                num_decoder_layers=num_layers
            )
            self.fc = nn.Linear(hidden_dim, input_dim)

        def forward(self, src, tgt):
            # src/tgt are unbatched token-id sequences of shape (seq_len,)
            src = self.embedding(src)
            tgt = self.embedding(tgt)
            output = self.transformer_layers(src, tgt)
            output = self.fc(output)
            return output

    # Create an instance of the Transformer model
    model2 = Transformer(input_dim=512, hidden_dim=512, num_heads=8, num_layers=6)

    # Use thop to estimate FLOPs (a 128-token source and a 64-token target sequence)
    flops, params = profile(model2, inputs=(torch.randint(0, 512, (128,)), torch.randint(0, 512, (64,))))
    print(f"FLOPS: {flops / 1e9} G FLOPS")  # print FLOPs in GFLOPs

    # Count and print the number of trainable parameters
    num_params = sum(p.numel() for p in model2.parameters() if p.requires_grad)
    print(f"Total number of trainable parameters: {num_params}")

  • Original article: https://blog.csdn.net/u013590327/article/details/133590669