• requires_grad when editing models from torchvision.models


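    When editing an existing torchvision model, you usually want to keep its pretrained weights and train only the layers you have changed. This can be done by setting each parameter's requires_grad attribute.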
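    The idea above can be tried without downloading any weights. The sketch below is an illustration under assumptions: a small nn.Sequential stands in for a pretrained torchvision model, and the layer sizes, learning rate and random data are arbitrary. It freezes every parameter, replaces the last layer, hands only the parameters that still require gradients to the optimizer, and runs one training step so you can see which parameters actually receive gradients.

```python
import torch
import torch.nn.functional as F
from torch import nn, optim

# Hypothetical stand-in for a pretrained model (no weights are downloaded)
torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(8, 16),  # "pretrained" layer to keep frozen
    nn.ReLU(),
    nn.Linear(16, 4),  # "pretrained" head that will be replaced
)

# Freeze every existing parameter
for param in model.parameters():
    param.requires_grad = False

# Replace the last layer; the new module's parameters are trainable
model[2] = nn.Linear(16, 3)

# Give the optimizer only the parameters that still require gradients
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable, lr=0.1)

# One training step on random data
x = torch.randn(5, 8)
y = torch.randint(0, 3, (5,))
loss = F.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# name, requires_grad, whether a gradient was computed
for name, param in model.named_parameters():
    print(name, param.requires_grad, param.grad is not None)
```

    Frozen parameters end up with grad set to None, so they are neither updated nor cost memory for gradient buffers; only the replaced layer is trained.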

    import torchvision
    from torch import nn

    def InitMode(mode_name):
        # Load a torchvision model together with its default pretrained weights
        if mode_name == "resnet152":
            return torchvision.models.resnet152(weights=torchvision.models.ResNet152_Weights.DEFAULT)
        elif mode_name == "resnet50":
            return torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        elif mode_name == "vgg16":
            return torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
        else:
            raise ValueError(f"unsupported model: {mode_name}")

    mymodel = InitMode("vgg16")

    # Before modification
    for name, param in mymodel.named_parameters():
        print(name, param.requires_grad)
    print("=" * 30)

    # Freeze all parameters by setting requires_grad to False
    for param in mymodel.parameters():
        param.requires_grad = False
    for name, param in mymodel.named_parameters():
        print(name, param.requires_grad)
    print("=" * 30)

    # Replace the last classifier layer with a new 10-class output layer
    mymodel.classifier[6] = nn.Linear(4096, 10)

    # After modification
    for name, param in mymodel.named_parameters():
        print(name, param.requires_grad)

    The output is as follows:

    D:\anaconda3\envs\pytorch_gpu\python.exe D:/project/python/pytorch_gpu/test.py
    features.0.weight True
    features.0.bias True
    features.2.weight True
    features.2.bias True
    features.5.weight True
    features.5.bias True
    features.7.weight True
    features.7.bias True
    features.10.weight True
    features.10.bias True
    features.12.weight True
    features.12.bias True
    features.14.weight True
    features.14.bias True
    features.17.weight True
    features.17.bias True
    features.19.weight True
    features.19.bias True
    features.21.weight True
    features.21.bias True
    features.24.weight True
    features.24.bias True
    features.26.weight True
    features.26.bias True
    features.28.weight True
    features.28.bias True
    classifier.0.weight True
    classifier.0.bias True
    classifier.3.weight True
    classifier.3.bias True
    classifier.6.weight True
    classifier.6.bias True

    ==============================
    features.0.weight False
    features.0.bias False
    features.2.weight False
    features.2.bias False
    features.5.weight False
    features.5.bias False
    features.7.weight False
    features.7.bias False
    features.10.weight False
    features.10.bias False
    features.12.weight False
    features.12.bias False
    features.14.weight False
    features.14.bias False
    features.17.weight False
    features.17.bias False
    features.19.weight False
    features.19.bias False
    features.21.weight False
    features.21.bias False
    features.24.weight False
    features.24.bias False
    features.26.weight False
    features.26.bias False
    features.28.weight False
    features.28.bias False
    classifier.0.weight False
    classifier.0.bias False
    classifier.3.weight False
    classifier.3.bias False
    classifier.6.weight False
    classifier.6.bias False

    ==============================
    features.0.weight False
    features.0.bias False
    features.2.weight False
    features.2.bias False
    features.5.weight False
    features.5.bias False
    features.7.weight False
    features.7.bias False
    features.10.weight False
    features.10.bias False
    features.12.weight False
    features.12.bias False
    features.14.weight False
    features.14.bias False
    features.17.weight False
    features.17.bias False
    features.19.weight False
    features.19.bias False
    features.21.weight False
    features.21.bias False
    features.24.weight False
    features.24.bias False
    features.26.weight False
    features.26.bias False
    features.28.weight False
    features.28.bias False
    classifier.0.weight False
    classifier.0.bias False
    classifier.3.weight False
    classifier.3.bias False
    classifier.6.weight True
    classifier.6.bias True

    Process finished with exit code 0

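    The output shows that after requires_grad is set to False on every parameter, the values change from the default True of the downloaded pretrained model to False.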
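    The freezing loop can also be written with nn.Module.requires_grad_(False), which sets the flag on every parameter of the module and its submodules. The sketch below uses a small hypothetical model and a helper named count_trainable (a name made up for this example) to count the trainable weights before and after freezing.

```python
from torch import nn

# Small hypothetical model used only to illustrate the call
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

def count_trainable(m: nn.Module) -> int:
    # Number of scalar weights that would receive gradients
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

print(count_trainable(model))  # 212: everything is trainable by default
model.requires_grad_(False)    # same effect as the per-parameter loop
print(count_trainable(model))  # 0
```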

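    After a layer is edited (here, classifier[6] is replaced with a new nn.Linear), that layer's requires_grad is True again, without any extra step.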

    I'm not yet sure why this happens, so I'm noting it here for the record.
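    A likely explanation: the assignment mymodel.classifier[6] = nn.Linear(4096, 10) does not edit the old layer in place. It replaces it with a newly constructed nn.Linear, and the nn.Parameter objects created in that constructor have requires_grad=True by default, while the frozen parameters of the old layer are simply dropped. The sketch below checks this on a small hypothetical model instead of VGG16.

```python
import torch
from torch import nn

# Small hypothetical model, fully frozen
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
model.requires_grad_(False)

old_weight = model[2].weight     # frozen parameter of the original head
print(old_weight.requires_grad)  # False

model[2] = nn.Linear(16, 10)     # a brand-new module, not an in-place edit
new_weight = model[2].weight
print(new_weight is old_weight)  # False: a different Parameter object
print(new_weight.requires_grad)  # True: nn.Parameter defaults to requires_grad=True

# Creating a Parameter directly shows the same default
print(nn.Parameter(torch.zeros(3)).requires_grad)  # True
```

    If the new layer should start out frozen as well, it can be frozen explicitly after the assignment, for example with model[2].requires_grad_(False).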

  • Original article: https://blog.csdn.net/immc1979/article/details/128099907