• Pytorch深度学习——优化算法、数据集类、数据加载器 05(未完)


    1 常见的优化算法

    1.1 梯度下降算法(BGD)

    每次迭代都需要把所有样本都送入,这样的好处是每次迭代都顾及了全部的样本,做的是全局最优化。

    1.2 随机梯度下降(SGD)

    针对梯度下降训练速度过慢的缺点,提出了随机梯度下降。

    随机梯度下降的算法是从样本中随机抽取一组,训练之后按梯度重新更新一次,然后再抽取一次,再更新一次。
    在torch中的API为:torch.optim.SGD()

    随机梯度下降是把所有样本分成了多个批次,而这些批次之间是没有交叉的,所以它的运算速度快,(刚开始快,后面可能会变的很慢)但是由于每一步梯度下降都不是全局最优,因此可能会陷入局部最优(在最优解附近移动),且不适用并行计算。

    1.3 小批量梯度下降(MBGD)

    这一种方式结合了上面两种的优点,小批量梯度下降的话:每次从样本中随机抽取一小批进行训练。(每次抽到的数据是可以交叉的,而且会覆盖所有的样本数据)

    效果是在1.1 和1.2 之间。

    1.4 动量法

    考虑到mini-batch SGD 在到达最优点的时候,可能并不是真正到达最优点,而是在最优点附近徘徊。

    mini-batch SGD需要我们挑选一个合适的学习率,因此也比较有难度。

    所以,Momentum 优化器刚好可以解决问题,它主要是基于梯度的移动指数加权平均,对网络的参数进行平滑处理,让梯度的摆动幅度变得最小。

    v = 0.8 v + 0.2 ∇ w , ∇ w 表示前一次的梯度 w = w − α v , α v 表示学习率 v = 0.8v + 0.2\nabla w , \nabla w表示前一次的梯度 \\w = w-\alpha v , \alpha v 表示学习率 v=0.8v+0.2∇w,w表示前一次的梯度w=wαv,αv表示学习率

    1.5 AdaGrad

    • 重点是自适应学习率, 在迭代次数不断增大,学习率是不断减小的。

    AdaGrad 算法就是将每一个参数的每一次迭代的梯度,取平方,累加后再开方,用全局学习率除以这个数,作为学习率的动态更新,从而达到自适应学习率的效果。

    在这里插入图片描述

    1.6 RMSProp

    • 是动量法的优化,对学习率进行加权。

    动量法是初步解决了优化中摆动幅度大的问题,为了进一步优化损失函数在更新中存在摆动幅度过大的问题,并且进一步加快函数的收敛速度。

    RMSProp算法对参数的梯度使用了平方加权平均数。
    在这里插入图片描述

    1.7 Adam

    Adam算法是将Momentum算法和RMSProp算法结合起来的算法。

    能够防止梯度的摆幅过大,同时还能够加快收敛速度。

    在这里插入图片描述
    torch中的API为:

    torch.optim.Adam()
    

    2 Pytorch中的数据加载

    • 这一块内容是比较重要的。

    在深度学习项目中,数据量通常是比较大的,面对大量的数据,是不可能一次性的在模型中进行向前的计算和反向传播的

    • 要求:而是需要对数据进行预处理,然后随机打乱整个数据集,把数据处理成一个一个batchs喂给模型。

    2.1 Dataset基类介绍

    在torch中提供了数据集的基类:torch.utils.data.Dataset

    查看一下源码:

    class Dataset(Generic[T_co]):
    
        functions: Dict[str, Callable] = {}
    
        def __getitem__(self, index) -> T_co:
            raise NotImplementedError
    
        def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
            return ConcatDataset([self, other])
    
    

    总结:

    在自定义的数据集类中,继承Dataset类,同时需要实现两个方法
    __len__ : 能够实现通过全局的len() 方法获取其中的元素个数;
    __getitem__ : 能够通过传入索引的方式获取数据,例如通过dataset[i]获取其中的第i条数据。

    2.2 数据集的案例

    数据来源:http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection

    数据介绍:SMS Spam Collection是用于骚扰短信识别的经典数据集,完全来自真实短信内容,包括4831条正常短信和747条骚扰短信。
    正常短信和骚扰短信保存在一个文本文件中。
    每行完整记录一条短信内容, 每行开头通过hamspam标识正常短信和骚扰短信。

    数据实例:
    在这里插入图片描述
    实现如下:

    注意:
    在python的字符串中,有一个比较有用的方式,就是在字符串之前加r
    尤其是在写路径的时候,比如说:path = "E:\study_self\LearnPytorch\practice\dataset", ‘’ 就会被当成转义字符,加r变成:r"E:\study_self\LearnPytorch\practice\dataset" 就告诉编译器这是一个原始字符串,在原始字符串中,是直接按照字面意思来使用字符串,没有转移字符、特殊字符或者其他不能打印的字符。

    import torch
    from torch.utils.data import Dataset
    
    datapath = r"E:\study_self\LearnPytorch\practice\dataset\SMSSpamCollection"
    
    
    # 完成数据集类
    class MyDataset(Dataset):
        def __init__(self):
            self.lines = open(datapath, encoding='UTF-8').readlines()
    
        def __getitem__(self, index):
            # 获取索引对应位置的一条数据
            return self.lines[index]
    
        def __len__(self):
            # 返回数据的总数量
            return len(self.lines)
    
    
    if __name__ == '__main__':
        my_dataset = MyDataset()
        print(my_dataset[0])
        print(len(my_dataset))
    
    
    

    在这里插入图片描述

    2.2.1 数据加载器类

    上述的方法能够进行数据的读取,但是其中还有很多内容没有实现:
    (1)批处理数据(batching the data)
    (2)打乱数据 (shuffling the data)
    (3)使用多线程 multiprocessing 并行加载数据

    3 Pytorch自带的数据集

    Pytorch自带的数据集由两个上层API提供,分别是torchvisiontorchtext

    1. torchvision 提供了对图片数据处理相关的API和数据
      例如: torchvision.datasets.MNIST(手写数字的图片数据)【继承自Dataset, 就是一个封装好了的Dataloader

    2. torchtext 提供了对文本数据处理相关的API和数据
      例如:torchtext.datasets.IMDB (电影 评论文本数据)

    3.1 MNIST API中的参数需要注意一下

    torchvision.datasets.MNIST(root='/files', train=True, download=True, transform=)
    
    1. root 参数表示数据存放的位置
    2. train 表示获取的是训练集还是测试集
    3. download : bool类型,表示是否需要下载数据到root目录
    4. transform 实现的对图片的处理函数(比如说打乱等等)

    3.2 MNIST数据集的介绍

    6万个训练,1万个测试,都是黑白图像,像素是28*28

    import torchvision
    
    
    dataset = torchvision.datasets.MNIST(root='E:\study_self\LearnPytorch\dataset\mnist', train=True, download=False, transform=None)
    # print(dataset)
    
    print(dataset[0])
    
    img = dataset[0][0]
    img.show()
    
    

    可以看出,返回值是(图片,目标值)

    在这里插入图片描述

    (<PIL.Image.Image image mode=L size=28x28 at 0x1DB6A802C70>, 5)
    
    Process finished with exit code 0
    
    
  • 相关阅读:
    云原生之K8S------k8s资源限制及探针检查
    QT 使用mysql
    node.js学习之模块化、npm
    基于Java+SpringBoot+LayUI仓库管理系统
    前端开发技术栈(工具篇):详细介绍npm、pnpm和cnpm分别是什么,使用方法以及之间有哪些关系
    MySQL性能优化
    【广州华锐互动】智能楼宇3D数字化展示,实现对建筑物的实时监控和管理
    php实战案例记录(10)单引号和双引号的用法和区别
    Vue中组件间的传值(子传父,父传子)
    Unity脚本判断场景内物体是否为Root Prefab的方法
  • 原文地址:https://blog.csdn.net/weixin_42521185/article/details/126909376