• PyTorch应用实战四:基于PyTorch构建复杂应用


    实验环境

    torch1.8.0+torchvision0.9.0

    import torch
    import torchvision
    print(torch.__version__)
    print(torchvision.__version__)
    
    • 1
    • 2
    • 3
    • 4
    1.8.0
    0.9.0+cpu
    
    • 1
    • 2

    1.PyTorch数据加载

    import torchvision.transforms as tfm
    from PIL import Image
    img = Image.open('volleyball.png')
    img_1 = tfm.RandomCrop(200, padding=50)(img)  #随机裁剪图片
    img_1.show()
    img_1.save('crop.png')
    img_2 = tfm.RandomHorizontalFlip()(img)       #随机水平翻转图片
    img_2.show()
    img_2.save('flip.png')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    1.1 数据预处理

    torchvision.transforms

    transfrom_train = tfm.Compose([
        tfm.RandomCrop(32, padding=4),
        tfm.RandomHorizontalFlip(),   
        tfm.ToTensor(),     #将图片转换为Tensor张量                       
        tfm.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))  #标准化
    ])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    1.2 数据加载

    torch.utils.data

    loader = torch.utils.data.DataLoader(
        datasets, batch_size=32, shuffle=True, sampler=None,
        num_workers=2, collate_fn=None, pin_memory=True, drop_last=False
    )
    
    • 1
    • 2
    • 3
    • 4
    • datasets:传入的数据集,可以是自定义的dataset对象或者torchvision中的预定义数据集对象。
    • batch_size:每个batch中包含的样本数量。
    • shuffle:是否打乱数据集。
    • sampler:样本抽样器,如果指定了sampler,则忽略shuffle参数。
    • num_workers:用于数据加载的子进程数量。
    • collate_fn:对样本进行批处理前的预处理函数,可用于对样本进行排序、padding等操作。
    • pin_memory:是否将数据加载到GPU的显存中。
    • drop_last:如果数据集样本数量不能被batch_size整除,则是否舍弃剩余的不足一个batch的样本。

    2.PyTorch模型搭建

    2.1 经典模型

    torchvision.models

    from torchvision import models
    net1 = models.resnet50()
    net2 = models.resnet50(pretrained=True)
    
    • 1
    • 2
    • 3

    2.2 模型加载与保存

    model.load_state_dict(torch.load('pretrained_weights.pth'))
    torch.save(model.state_dict(), 'model_weights.pth')
    
    • 1
    • 2

    3.PyTorch优化器

    3.1 torch.optim

    optimizer = optim.SGD([       #SGD随机梯度下降算法
        {'params':model.base.parameters()},
        {'params':model.classifier.parameters(), 'lr': 1e-3}
    ], lr=1e-2, momentum=0.9)
    
    • 1
    • 2
    • 3
    • 4
    # 训练过程 
    model = init_model_function()               #模型构建
    optimizer = optim.SomeOptimizer(            #设置优化器
        model.parameters(), lr, mm
    )
    
    for data, label in train_dataloader:
        optimizer.zero_grad()                #前向计算前,清空原有梯度
        output = model(data)                 #前向计算
        loss = loss_function(output, label)  #损失函数
        loss.backward()                      #反向传播 
        optimizer.step()                     #更新参数
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    3.2 学习率调整

    scheduler = optim.lr_scheduler.SomeScheduler(optimizer, *args)
    for epoch in range(epochs):
        train()
        test()
        scheduler.step()
    
    • 1
    • 2
    • 3
    • 4
    • 5

    常见函数

    激活单元类型

    ELUMultiheadAttentionSELUsoftshrinkSoftmin
    HardshrinkPReLUCELUSoftsignSoftmax
    HardtanhReLUGELUTanhSoftmax2d
    LeakyReLUReLU6SigmoidTanhshrinkLogSoftmax
    LogSigmoidRReLUSoftplusThreshold

    损失函数层类型

    L1LossPoissonNLLLossHingeEmbeddingLossCosineEmbeddingLoss
    MSELossKLDivLossMultiLabelMarginLossMultiMarginLoss
    CrossEntropyLossBCELossSmoothL1LossTripletMarginLoss
    CTCLossBCEWithLogitsLossSoftMarginLoss
    NLLLossMarginRankingLossMultiLabelSoftMarginLoss

    优化器类型

    AdadeltaAdamWASGDRprop
    AdagradSparseAdamLBFGSSGD
    AdamAdamaxRMSprop

    变换操作类型

    ComposeRandomAffineRandomOrderResizeToTensor
    CenterCropRandomApplyRandomPerspectiveScaleLambda
    ColorJitterRandomChoiceRandomResizedCropTenCrop
    FiveCropRandomCropRandomRotationLinearTransformation
    GrayscaleRandomGrayscaleRandomSizedCropNormalize
    PadRandomHorizontalFlipRandomVerticalFlipToPILImage

    数据集名称

    MNISTCocoCaptionsCIFAR10Flickr8kUSPS
    FashionMNISTcocoDetectionCIFAR100Flickr30kKinetics400
    KMNISTLSUNSTL10VOCSegmentationHMDB51
    EMNISTImageFolderSVHNVOCDetectionUCF101
    QMNISTDatasetFolderPhotoTourCityscapeCelebA
    FakeDataImageNetSBUSBDataset

    torchvision.models中所有实现的分类模型

    AlexNetVGG-13-bnResNet-101Densenet-201ResNeXt-50-32x4d
    VGG-11VGG-16-bnResNet-152Densenet-161ResNeXt-101-32x8d
    VGG-13VGG-19-bnSqueezeNetInception-V3Wide ResNet-50-2
    VGG-16ResNet-18GoogleNetWide ResNet-101-2
    VGG-19ResNet-34Densenet-121ShuffleNet-V2MNASNet 1.0
    VGG-11-bnResNet-50Densenet-169MobileNet-V2

    附:系列文章

    序号文章目录直达链接
    1PyTorch应用实战一:实现卷积操作https://want595.blog.csdn.net/article/details/132575530
    2PyTorch应用实战二:实现卷积神经网络进行图像分类https://want595.blog.csdn.net/article/details/132575702
    3PyTorch应用实战三:构建神经网络https://want595.blog.csdn.net/article/details/132575758
    4PyTorch应用实战四:基于PyTorch构建复杂应用https://want595.blog.csdn.net/article/details/132625270
    5PyTorch应用实战五:实现二值化神经网络https://want595.blog.csdn.net/article/details/132625348
    6PyTorch应用实战六:利用LSTM实现文本情感分类https://want595.blog.csdn.net/article/details/132625382
  • 相关阅读:
    计算机毕设 深度学习 机器学习 酒店评价情感分析算法实现
    常见的积分:数理方程中常见的复杂积分
    1075 PAT Judge
    百度文心一言GPT免费入口也来了!!!
    【华为机试真题 JAVA】找终点-100
    CCIE理论-IPSec的主模式和野蛮模式的区别
    达梦数据库使用IPV6连接
    Linux--VMware的安装和Centos
    GBase 8d的特性-可扩展性
    Jetpack Compose学习(9)——Compose中的列表控件(LazyRow和LazyColumn)
  • 原文地址:https://blog.csdn.net/m0_68111267/article/details/132625270