• 【深度学习】实验12 使用PyTorch训练模型


    使用PyTorch训练模型

    PyTorch是一个基于Python的科学计算库,它是一个开源的机器学习框架,由Facebook公司于2016年开源。它提供了构建动态计算图的功能,可以更自然地使用Python语言编写深度神经网络的程序,具有易于使用、灵活、高效等特点,被广泛应用于深度学习任务中。

    PyTorch的核心是动态计算图(Dynamic Computational Graph),这意味着计算图是在运行时动态生成的,而不是预先编译好的。这个特点使得PyTorch具有高度的灵活性,可以更加轻松地进行实验和调试。同时,它也有一个静态计算图模块,可以用于生产环境中,提高计算效率。

    另外,PyTorch的另一个特点是它的张量计算。张量是PyTorch中的核心数据结构,类似于NumPy中的数组。PyTorch支持GPU加速,可以使用GPU进行张量计算,大大提高了计算效率。同时,它也支持自动求导功能,可以自动计算张量的梯度,使得深度学习的模型训练更加便捷。

    PyTorch还提供了丰富的模型库,包括经典的深度学习模型,如卷积神经网络(CNN)、循环神经网络(RNN)和生成对抗网络(GAN),以及各种领域的预训练模型,如自然语言处理(NLP)和计算机视觉(CV),可以快速搭建和训练模型。

    PyTorch也具有良好的社区支持。它的文档详细且易于理解,社区提供了大量的示例和教程,可以帮助用户更好地学习和使用PyTorch。同时,PyTorch还有一个活跃的开发团队,定期发布新的版本,修复bug和增加新的特性,保证了PyTorch的稳定性和可用性。

    总的来说,PyTorch是一个强大、灵活、易于使用的机器学习框架,具有良好的社区支持和广泛的应用领域,能够满足不同用户的需求。随着人工智能的不断发展,PyTorch的应用将会更加广泛。

    1. 线性回归类

    import torch
    import numpy as np
    import matplotlib.pyplot as plt
    class LinearRegression(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(1, 1)
            self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
            self.loss_function = torch.nn.MSELoss()
        
        def forward(self, x):
            out = self.linear(x)
            return out
       
        def train(self, data, model_save_path='model.path'):
            x = data["x"]
            y = data["y"]
            for epoch in range(10000):
                prediction = self.forward(x)
                loss = self.loss_function(prediction, y)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                if epoch % 100 == 0:
                    print("epoch:{}, loss is:{}".format(epoch, loss.item()))
            torch.save(self.state_dict(), "linear.pth")
        def test(self, x, model_path="linear.pth"):
            x = data["x"]
            y = data["y"]
            self.load_state_dict(torch.load(model_path))
            prediction = self.forward(x)
            plt.scatter(x.numpy(), y.numpy(), c=x.numpy())
            plt.plot(x.numpy(), prediction.detach().numpy(), color="r")
            plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34

    该Python代码实现了一个简单的线性回归模型,并进行了训练和测试。

    首先,导入了PyTorch、NumPy和Matplotlib.pyplot库。

    接下来,定义了一个名为LinearRegression的类,它是一个继承自torch.nn.Module的类,因此可以利用PyTorch的自动求导和优化功能。在该类的初始化方法中,定义了一个torch.nn.Linear对象,它表示一个全连接层,输入大小为1,输出大小为1;并定义了一个torch.optim.SGD对象,它表示随机梯度下降法的优化器,学习率为0.01;以及一个torch.nn.MSELoss对象,它表示均方误差损失函数。

    接下来,定义了一个名为forward的方法,它表示前向传递过程,即对输入进行线性变换,得到输出。

    然后,定义了一个名为train的方法,它接受一个数据字典和一个模型保存路径作为输入。该方法首先从数据字典中获取输入数据x和输出数据y,然后进行10000次迭代训练。在每次迭代中,先将输入数据x送入模型中得到预测输出prediction,然后计算预测输出和真实输出之间的均方误差损失loss,并进行反向传播和参数优化。每100次迭代打印一次损失值。最后将模型参数保存到指定的文件路径中。

    最后,定义了一个名为test的方法,它接受一个输入数据x和一个模型保存路径作为输入。该方法首先从文件中加载训练好的模型参数,然后将输入数据x送入模型中得到预测输出prediction,并将预测输出和真实输出以及输入数据可视化展示出来。

    总之,这段代码实现了一个简单的线性回归模型,并可以通过train方法进行训练,通过test方法进行测试和可视化展示。

    2. 创建数据集

    def create_linear_data(nums_data, if_plot=False):
        x = torch.linspace(0, 1, nums_data)
        x = torch.unsqueeze(x, dim = 1)
        k = 2
        y = k * x + torch.rand(x.size())
        if if_plot:
            plt.scatter(x.numpy(), y.numpy(), c=x.numpy())
            plt.show()
        data = {"x":x, "y":y}
        return data
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    data = create_linear_data(300, if_plot=True)
    
    • 1

    1

    3. 训练模型

    model = LinearRegression()
    model.train(data)
    
    • 1
    • 2
       epoch:0, loss is:3.8653182983398438
       epoch:100, loss is:0.31251025199890137
       epoch:200, loss is:0.2438090741634369
       epoch:300, loss is:0.20671892166137695
       epoch:400, loss is:0.17835141718387604
       epoch:500, loss is:0.15658551454544067
       epoch:600, loss is:0.13988454639911652
       epoch:700, loss is:0.12706983089447021
       epoch:800, loss is:0.11723710596561432
       epoch:900, loss is:0.10969242453575134
       epoch:1000, loss is:0.10390334576368332
       epoch:1100, loss is:0.09946136921644211
       epoch:1200, loss is:0.09605306386947632
       epoch:1300, loss is:0.09343785047531128
       epoch:1400, loss is:0.09143117070198059
       epoch:1500, loss is:0.0898914709687233
       epoch:1600, loss is:0.08871004730463028
       epoch:1700, loss is:0.08780352771282196
       epoch:1800, loss is:0.08710794895887375
       epoch:1900, loss is:0.08657423406839371
       epoch:2000, loss is:0.08616471290588379
       epoch:2100, loss is:0.08585048466920853
       epoch:2200, loss is:0.08560937643051147
       epoch:2300, loss is:0.08542437106370926
       epoch:2400, loss is:0.08528240770101547
       epoch:2500, loss is:0.08517350256443024
       epoch:2600, loss is:0.08508992940187454
       epoch:2700, loss is:0.08502580225467682
       epoch:2800, loss is:0.08497659116983414
       epoch:2900, loss is:0.08493883907794952
       epoch:3000, loss is:0.08490986377000809
       epoch:3100, loss is:0.08488764613866806
       epoch:3200, loss is:0.08487057685852051
       epoch:3300, loss is:0.08485749363899231
       epoch:3400, loss is:0.08484745025634766
       epoch:3500, loss is:0.08483975380659103
       epoch:3600, loss is:0.08483383059501648
       epoch:3700, loss is:0.08482930809259415
       epoch:3800, loss is:0.08482582122087479
       epoch:3900, loss is:0.08482315391302109
       epoch:4000, loss is:0.08482109755277634
       epoch:4100, loss is:0.08481952548027039
       epoch:4200, loss is:0.08481831848621368
       epoch:4300, loss is:0.08481740206480026
       epoch:4400, loss is:0.08481667935848236
       epoch:4500, loss is:0.08481614291667938
       epoch:4600, loss is:0.08481571823358536
       epoch:4700, loss is:0.08481539785861969
       epoch:4800, loss is:0.08481515198945999
       epoch:4900, loss is:0.08481497317552567
       epoch:5000, loss is:0.08481481671333313
       epoch:5100, loss is:0.08481471240520477
       epoch:5200, loss is:0.08481462299823761
       epoch:5300, loss is:0.08481455594301224
       epoch:5400, loss is:0.08481451123952866
       epoch:5500, loss is:0.08481448143720627
       epoch:5600, loss is:0.08481443673372269
       epoch:5700, loss is:0.08481442183256149
       epoch:5800, loss is:0.0848143994808197
       epoch:5900, loss is:0.0848143920302391
       epoch:6000, loss is:0.08481437712907791
       epoch:6100, loss is:0.08481436222791672
       epoch:6200, loss is:0.08481435477733612
       epoch:6300, loss is:0.08481435477733612
       epoch:6400, loss is:0.08481435477733612
       epoch:6500, loss is:0.08481435477733612
       epoch:6600, loss is:0.08481435477733612
       epoch:6700, loss is:0.08481435477733612
       epoch:6800, loss is:0.08481434732675552
       epoch:6900, loss is:0.08481435477733612
       epoch:7000, loss is:0.08481433987617493
       epoch:7100, loss is:0.08481435477733612
       epoch:7200, loss is:0.08481433987617493
       epoch:7300, loss is:0.08481433987617493
       epoch:7400, loss is:0.08481434732675552
       epoch:7500, loss is:0.08481434732675552
       epoch:7600, loss is:0.08481434732675552
       epoch:7700, loss is:0.08481434732675552
       epoch:7800, loss is:0.08481434732675552
       epoch:7900, loss is:0.08481434732675552
       epoch:8000, loss is:0.08481434732675552
       epoch:8100, loss is:0.08481434732675552
       epoch:8200, loss is:0.08481434732675552
       epoch:8300, loss is:0.08481434732675552
       epoch:8400, loss is:0.08481434732675552
       epoch:8500, loss is:0.08481434732675552
       epoch:8600, loss is:0.08481434732675552
       epoch:8700, loss is:0.08481434732675552
       epoch:8800, loss is:0.08481434732675552
       epoch:8900, loss is:0.08481434732675552
       epoch:9000, loss is:0.08481434732675552
       epoch:9100, loss is:0.08481434732675552
       epoch:9200, loss is:0.08481434732675552
       epoch:9300, loss is:0.08481434732675552
       epoch:9400, loss is:0.08481434732675552
       epoch:9500, loss is:0.08481434732675552
       epoch:9600, loss is:0.08481434732675552
       epoch:9700, loss is:0.08481434732675552
       epoch:9800, loss is:0.08481434732675552
       epoch:9900, loss is:0.08481434732675552
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    model.test(data)
    
    • 1

    4. 测试模型

    2

    附:系列文章

    序号文章目录直达链接
    1波士顿房价预测https://want595.blog.csdn.net/article/details/132181950
    2鸢尾花数据集分析https://want595.blog.csdn.net/article/details/132182057
    3特征处理https://want595.blog.csdn.net/article/details/132182165
    4交叉验证https://want595.blog.csdn.net/article/details/132182238
    5构造神经网络示例https://want595.blog.csdn.net/article/details/132182341
    6使用TensorFlow完成线性回归https://want595.blog.csdn.net/article/details/132182417
    7使用TensorFlow完成逻辑回归https://want595.blog.csdn.net/article/details/132182496
    8TensorBoard案例https://want595.blog.csdn.net/article/details/132182584
    9使用Keras完成线性回归https://want595.blog.csdn.net/article/details/132182723
    10使用Keras完成逻辑回归https://want595.blog.csdn.net/article/details/132182795
    11使用Keras预训练模型完成猫狗识别https://want595.blog.csdn.net/article/details/132243928
    12使用PyTorch训练模型https://want595.blog.csdn.net/article/details/132243989
    13使用Dropout抑制过拟合https://want595.blog.csdn.net/article/details/132244111
    14使用CNN完成MNIST手写体识别(TensorFlow)https://want595.blog.csdn.net/article/details/132244499
    15使用CNN完成MNIST手写体识别(Keras)https://want595.blog.csdn.net/article/details/132244552
    16使用CNN完成MNIST手写体识别(PyTorch)https://want595.blog.csdn.net/article/details/132244641
    17使用GAN生成手写数字样本https://want595.blog.csdn.net/article/details/132244764
    18自然语言处理https://want595.blog.csdn.net/article/details/132276591
  • 相关阅读:
    点亮.NET的文字云艺术之光——Sdcb.WordCloud 2.0
    新版TCGAbiolinks包学习03:差异分析
    项目代码标准化
    如何做好测试?(十一)可用性测试 (Usability Testing)
    【PG】PostgreSQL逻辑备份(pg_dump)
    MES系统是如何采集Modbus设备数据的呢?
    小黑子—MyBatis:第二章
    m基于STBC的MIMO通信系统性能仿真和信道容量仿真
    协程的创建
    苹果放出快捷指令专题介绍页面,大大提高了 Mac 使用效率
  • 原文地址:https://blog.csdn.net/m0_68111267/article/details/132243989