• 冻结ResNet50前几层并进行迁移学习(PyTorch)


    在PyTorch中,加载ResNet50模型并冻结模型的前几层可以通过以下步骤进行:

    import torch
    from torchvision.models import resnet50
    
    # 设置GPU环境
    use_cuda = True
    device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")
    
    # 加载预训练的ResNet50模型
    trained_model = resnet50(pretrained=True)
    
    # 冻结需要保持不变的层,通常是前几个卷积层
    for name, param in trained_model.named_parameters():
        if 'conv1' in name or 'bn1' in name or 'conv2' in name or 'bn2' in name or 'conv3' in name or 'bn3' in name:
            param.requires_grad = False
    
    # 修改最后一层进行微调
    model = nn.Sequential(*list(trained_model.children())[:-1],
                            Flatten(),  # [b,2048]
                            nn.Linear(2048, 4),  # 假设输出类别数为4
                            ).to(device)
    
    # 损失、优化器
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    
    # 训练模型
    for epoch in range(epochs):
        total_loss = 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
    
            optimizer.zero_grad()
    
            # 前向传播
            outputs = model(imgs)
            loss = criterion(outputs, labels)
    
            # 反向传播和优化
            loss.backward()
            optimizer.step()
    
            total_loss += loss.item()
    
        # 打印每个epoch的损失值
        print(f"Epoch {epoch+1} Loss: {total_loss /len(train_loader)}")
    
    # 保存模型参数
    torch.save(model.state_dict(), 'retrain_resnet50.pth')
    
    # 测试模型
    def evalute(model, loader):
        model.eval()
        correct = 0
        total = len(loader.dataset)
    
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                logits = model(x)
                pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
        return correct/total
    
    model.load_state_dict(torch.load('retrain_resnet50.pth'))
    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
  • 相关阅读:
    【R语言】把文件夹下的所有文件提取到特定文件夹
    Pytroch Nerf代码阅读笔记(LLFF 数据集pose 处理和Nerf 网络结构)
    npm install 报错解决记录
    WPF实现树形表格控件(TreeListView)
    Springboot文件上传
    解锁新技能《Redis ACL SETUSER命令》
    与迭代次数有关的一种差值结构
    为什么我的MySQL会抖一下?
    前后端分离,JSON数据如何交互
    使用cpolar发布群晖NAS上的网页 上篇(7.X版)
  • 原文地址:https://blog.csdn.net/weixin_43856668/article/details/134048544