• 分类网络搭建示例


    搭建CNN网络

    本章我们来学习一下如何搭建网络,初始化方法,模型的保存,预训练模型的加载方法。本专栏需要搭建的是对分类性能的测试,所以这里我们只以VGG为例。

    请注意,这里定义的只是一个简陋的版本,后续一些经典网络的学习,我们会在另外单独去开一个专栏讲解。

    1. 网络搭建

    在PyTorch中,你可以使用 torchvision.models 中的 vgg16 来加载预定义的VGG16模型,也可以手动定义。以下是手动定义的一个简化版本:

    1. import torch
    2. import torch.nn as nn
    3. class VGG16(nn.Module):
    4. def __init__(self, num_classes=1000):
    5. super(VGG16, self).__init__()
    6. self.features = nn.Sequential(
    7. nn.Conv2d(3, 64, kernel_size=3, padding=1),
    8. nn.ReLU(inplace=True),
    9. nn.Conv2d(64, 64, kernel_size=3, padding=1),
    10. nn.ReLU(inplace=True),
    11. nn.MaxPool2d(kernel_size=2, stride=2),
    12. nn.Conv2d(64, 128, kernel_size=3, padding=1),
    13. nn.ReLU(inplace=True),
    14. nn.Conv2d(128, 128, kernel_size=3, padding=1),
    15. nn.ReLU(inplace=True),
    16. nn.MaxPool2d(kernel_size=2, stride=2),
    17. nn.Conv2d(128, 256, kernel_size=3, padding=1),
    18. nn.ReLU(inplace=True),
    19. nn.Conv2d(256, 256, kernel_size=3, padding=1),
    20. nn.ReLU(inplace=True),
    21. nn.Conv2d(256, 256, kernel_size=3, padding=1),
    22. nn.ReLU(inplace=True),
    23. nn.MaxPool2d(kernel_size=2, stride=2),
    24. nn.Conv2d(256, 512, kernel_size=3, padding=1),
    25. nn.ReLU(inplace=True),
    26. nn.Conv2d(512, 512, kernel_size=3, padding=1),
    27. nn.ReLU(inplace=True),
    28. nn.Conv2d(512, 512, kernel_size=3, padding=1),
    29. nn.ReLU(inplace=True),
    30. nn.MaxPool2d(kernel_size=2, stride=2),
    31. nn.Conv2d(512, 512, kernel_size=3, padding=1),
    32. nn.ReLU(inplace=True),
    33. nn.Conv2d(512, 512, kernel_size=3, padding=1),
    34. nn.ReLU(inplace=True),
    35. nn.Conv2d(512, 512, kernel_size=3, padding=1),
    36. nn.ReLU(inplace=True),
    37. nn.MaxPool2d(kernel_size=2, stride=2),
    38. )
    39. self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
    40. self.classifier = nn.Sequential(
    41. nn.Linear(512 * 7 * 7, 4096),
    42. nn.ReLU(inplace=True),
    43. nn.Dropout(),
    44. nn.Linear(4096, 4096),
    45. nn.ReLU(inplace=True),
    46. nn.Dropout(),
    47. nn.Linear(4096, num_classes),
    48. )
    49. def forward(self, x):
    50. x = self.features(x)
    51. x = self.avgpool(x)
    52. x = torch.flatten(x, 1)
    53. x = self.classifier(x)
    54. return x

    2. 初始化方法

    在这里,我们不再手动初始化每一层,因为PyTorch的默认初始化通常足够好。你可以选择手动初始化,如果需要,可以使用 torch.nn.init 中的不同方法。

    3. 模型的保存

    使用 torch.save 保存VGG16模型:

    1. vgg16 = VGG16()
    2. torch.save(vgg16.state_dict(), 'vgg16_model.pth')

    4. 预训练模型的加载

    要加载预训练的VGG16模型,你可以使用 torchvision.models 中的 vgg16(pretrained=True),或者手动加载预训练权重:

    1. vgg16 = VGG16()
    2. vgg16.load_state_dict(torch.load('pretrained_vgg16.pth'))

    请确保路径 'pretrained_vgg16.pth' 是你预训练模型文件的实际路径。你可以从PyTorch的官方模型库或其他来源下载预训练权重。

    上面是最简单的一种模型全部加载的方式,但也有一些情况下,只是想加载其中一部分层的参数。剩下一部分由于已经改变参数了,无法加载预训练模型,所以要选择随机初始化。 、

    这里我们来观察网络怎么去表示的:

    1. if __name__ == "__main__":
    2. model = VGG16()
    3. for name, value in model.named_parameters():
    4. print(name)

    下面就是控制台打印出的部分信息。 

    这两行的输出就是打印网络层的名字,实际上加载预训练模型时,也是按照这个名字来加载的。

    1. # 加载预训练 VGG16 模型的参数
    2. pretrained_dict = torch.load('pretrained_vgg16.pth')
    3. # 剔除预训练模型中全连接层的参数
    4. pretrained_dict.pop('classifier.0.weight')
    5. pretrained_dict.pop('classifier.0.bias')
    6. pretrained_dict.pop('classifier.3.weight')
    7. pretrained_dict.pop('classifier.3.bias')
    8. pretrained_dict.pop('classifier.6.weight')
    9. pretrained_dict.pop('classifier.6.bias')
    10. # 获取自定义模型的参数字典
    11. model_dict = model.state_dict()
    12. # 更新自定义模型的参数字典,加载预训练模型的参数值
    13. model_dict.update(pretrained_dict)
    14. # 加载更新后的参数字典到自定义模型中
    15. model.load_state_dict(model_dict)

    自己定义的一些层是不会出现在pretrained_dict中,因此会将其剔除,从而只加载了 pretrained_dict中有的层。

    总结

    本章只是对网络的定义进行一个简单的示例,具体的部分我们会在另外一个专栏讲解,这里只是为了让读者了解网络定义的流程。在实际项目中,通常需要更详细的网络结构,包括适当的初始化方法、损失函数的选择、优化器的设置等。如果读者了解掌握了基本的网络定义过程,你可以在本专栏中深入讲解这些方面,以及如何训练和评估模型等内容。

  • 相关阅读:
    Vue2中10种组件通信方式和实践技巧
    【Azure 事件中心】使用Azure AD认证方式创建Event Hub Consume Client + 自定义Event Position
    百度10年架构师分享的(Java TCP/IP Socket编程开发经验)看完受益匪浅!
    状态保持-JWT
    2331. 计算布尔二叉树的值
    CSS 中的 white-space 渲染模型
    Win10重启后总是自动打开上次未关闭的程序怎么办
    Kubernetes学习笔记-保障集群内节点和网络安全20220827
    [附源码]java毕业设计暖暖猫窝系统
    react,三个DatePicker组件实现时间限制(开奖时间>截至时间>开始时间)
  • 原文地址:https://blog.csdn.net/m0_62919535/article/details/134357254