• 6-2 pytorch中训练模型的3种方法


    Pytorch通常需要用户编写自定义训练循环,训练循环的代码风格因人而异。(养成自己的习惯)
    有3类典型的训练循环代码风格脚本形式训练循环,函数形式训练循环,类形式训练循环。
    下面以minist数据集的多分类模型的训练为例,演示这3种训练模型的风格。
    其中类形式训练循环我们同时演示torchkeras.KerasModel和torchkeras.LightModel两种示范。

    准备数据

    transform = transforms.Compose([transforms.ToTensor()])
    
    ds_train = torchvision.datasets.MNIST(root="./data/mnist/",train=True,download=True,transform=transform)
    ds_val = torchvision.datasets.MNIST(root="./data/mnist/",train=False,download=True,transform=transform)
    
    dl_train =  torch.utils.data.DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=4)
    dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=128, shuffle=False, num_workers=4)
    
    print(len(ds_train))
    print(len(ds_val))
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    image.png

    %matplotlib inline
    %config InlineBackend.figure_format = 'svg'
    
    #查看部分样本
    from matplotlib import pyplot as plt 
    
    plt.figure(figsize=(8,8)) 
    for i in range(9):
        img,label = ds_train[i] 
        img = torch.squeeze(img) # 删除为1的维度
        ax=plt.subplot(3,3,i+1)
        ax.imshow(img.numpy())
        ax.set_title("label = %d"%label)
        ax.set_xticks([])
        ax.set_yticks([]) 
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    image.png

    一、脚本风格

    脚本风格的训练循环非常常见。

    net = nn.Sequential()
    net.add_module("conv1",nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3))
    net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2))
    net.add_module("conv2",nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5))
    net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))
    net.add_module("dropout",nn.Dropout2d(p = 0.1))
    net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))
    net.add_module("flatten",nn.Flatten())
    net.add_module("linear1",nn.Linear(64,32))
    net.add_module("relu",nn.ReLU())
    net.add_module("linear2",nn.Linear(32,10))
    
    print(net)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    image.png
    代码量较多,可以查看最下方链接对应的notebook。

    二、函数风格

    该风格在脚本形式上做了进一步的函数封装。

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.layers = nn.ModuleList([
                nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),
                nn.MaxPool2d(kernel_size = 2,stride = 2),
                nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
                nn.MaxPool2d(kernel_size = 2,stride = 2),
                nn.Dropout2d(p = 0.1),
                nn.AdaptiveMaxPool2d((1,1)),
                nn.Flatten(),
                nn.Linear(64,32),
                nn.ReLU(),
                nn.Linear(32,10)]
            )
        def forward(self,x):
            for layer in self.layers:
                x = layer(x)
            return x
    net = Net()
    print(net)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    image.png
    代码量较多,可以查看最下方链接对应的notebook。

    三、类风格

    此处使用**torchkeras.KerasModel(其源码其实就是脚本风格中的代码)**高层次API接口中的fit方法训练模型。
    使用该形式训练模型非常简洁明了。
    先构建模型,同一二。

    from torchkeras import KerasModel 
    
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([
                nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),
                nn.MaxPool2d(kernel_size = 2,stride = 2),
                nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
                nn.MaxPool2d(kernel_size = 2,stride = 2),
                nn.Dropout2d(p = 0.1),
                nn.AdaptiveMaxPool2d((1,1)),
                nn.Flatten(),
                nn.Linear(64,32),
                nn.ReLU(),
                nn.Linear(32,10)]
            )
        def forward(self,x):
            for layer in self.layers:
                x = layer(x)
            return x
        
    net = Net() 
    
    print(net)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25

    使用kerasModel:

    from torchmetrics import Accuracy
    
    model = KerasModel(net,
                       loss_fn=nn.CrossEntropyLoss(),
                       metrics_dict = {"acc":Accuracy(task='multiclass',num_classes=10)},
                       optimizer = torch.optim.Adam(net.parameters(),lr = 0.01)  )
    
    model.fit(
        train_data = dl_train,
        val_data= dl_val,
        epochs=10,
        patience=3,
        monitor="val_acc", 
        mode="max",
        plot=True,
        cpu=True
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    训练过程:
    image.png
    其实编码训练代码按照自己的习惯即可,不必要按照以上三种方式。

    参考:https://github.com/lyhue1991/eat_pytorch_in_20_days

  • 相关阅读:
    uniapp 动态切换应用图标、名称插件(如新年、国庆等) Ba-ChangeIcon
    989. 数组形式的整数加法
    DC-1靶场搭建及渗透实战详细过程(DC靶场系列)
    Unity云原生分布式运行时
    Redis数据库安全之旅
    css的rotate3d实现炫酷的圆环转动动画
    【pytorch问题解决】OSError: [WinError 1455] 页面文件太小,无法完成操作。
    LeetCode刷题笔记【35】:动态规划专题-7(爬楼梯、零钱兑换、完全平方数)
    数组与链表
    使用Python进行广告点击率预测
  • 原文地址:https://blog.csdn.net/hxhabcd123/article/details/132996797