• 034、test


    之——全纪录

    目录

    之——全纪录

    杂谈

    正文

    1.下载处理数据

    2.数据集概览

    3.构建自定义dataset

    4.初始化网络

    5.训练


    杂谈

            综合方法试一下。


    leaves

    1.下载处理数据

            从官网下载数据集:Classify Leaves | Kaggle

            解压后有一个图片集,一个提交示例,一个测试集,一个训练集。

            images,27153个树叶图片:

            test.csv,8800个:

            train.csv,18353个:


    2.数据集概览

            训练集、测试集、类别:

    1. #导包
    2. import random
    3. import torch
    4. from torch import nn
    5. from torch.nn import functional as F
    6. from torchvision import datasets, transforms
    7. import torchvision
    8. import pandas as pd
    9. import matplotlib.pyplot as plt
    10. from d2l import torch as d2l
    11. from PIL import Image
    12. train_data=pd.read_csv(r"D:\apycharmblackhorse\leaves\train.csv")
    13. test_data=pd.read_csv(r"D:\apycharmblackhorse/leaves/test.csv")
    14. train_images=train_data.iloc[:,0].values #把所有的训练集图片路径读进来成list
    15. print("训练集数量:",len(train_images))
    16. n_train=len(train_images)
    17. test_images=test_data.iloc[:,0].values
    18. print("测试集数量:",len(test_images))
    19. n_test=len(test_images)
    20. train_labels = pd.get_dummies(train_data.iloc[:, 1]).values.astype(int).argmax(1)
    21. #独热编码后找到每行最大的索引记下来就是类别号,而顺序与独热编码colums,也就是与下方排序一致
    22. # print(len(train_labels),train_labels)
    23. #记录并排序所有的类别名
    24. train_labels_header = pd.get_dummies(train_data.iloc[:, 1]).columns.values
    25. print("总类别:",len(train_labels_header))
    26. classes=len(train_labels_header)


    3.构建自定义dataset

           继承 torch.utils.data.Dataset 类,自定义树叶分类数据集:

    1. #继承 torch.utils.Dataset 类,自定义树叶分类数据集
    2. class leaves_dataset(torch.utils.data.Dataset):
    3. #root数据目录, images图片路径, labels图片标签, transform数据增强
    4. def __init__(self, root, images, labels, transform):
    5. super(leaves_dataset, self).__init__()
    6. self.root = root
    7. self.images = images
    8. if labels is None:
    9. self.labels = None
    10. else:
    11. self.labels = labels
    12. self.transform = transform
    13. #获得指定样本
    14. def __getitem__(self, index):
    15. image_path = self.root + self.images[index]
    16. image = Image.open(image_path)
    17. #预处理
    18. image = self.transform(image)
    19. if self.labels is None:
    20. return image
    21. label = torch.tensor(self.labels[index])
    22. return image, label
    23. #获得数据集长度
    24. def __len__(self):
    25. return self.images.shape[0]

            构建读取数据与预处理:

    1. def load_data(images, labels, batch_size, train):
    2. aug = []
    3. normalize = torchvision.transforms.Normalize(
    4. [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    5. if (train):
    6. aug = [torchvision.transforms.CenterCrop(224),
    7. transforms.RandomHorizontalFlip(),
    8. transforms.RandomVerticalFlip(),
    9. transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
    10. transforms.ToTensor(),
    11. normalize]
    12. else:
    13. aug = [torchvision.transforms.Resize([256, 256]),
    14. torchvision.transforms.CenterCrop(224),
    15. transforms.ToTensor(),
    16. normalize]
    17. transform = transforms.Compose(aug)
    18. dataset = leaves_dataset(r"D:\apycharmblackhorse\leaves\\", images, labels, transform=transform)
    19. if train==True:type="训练"
    20. else:type="测试"
    21. print("载入:",dataset.__len__(),type)
    22. return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, num_workers=0, shuffle=train)
    23. train_iter = load_data(train_images, train_labels, 512, train=True)

    4.初始化网络

            使用官方预训练模型初始化网络,并修改输出类别数:

    1. #初始化网络
    2. net = torchvision.models.resnet18(pretrained=True)
    3. net.fc = nn.Linear(net.fc.in_features, classes)
    4. nn.init.xavier_uniform_(net.fc.weight)
    5. net.fc


    5.训练

             定义迭代器、优化器以及其他超参数,进行训练:

    1. # 如果param_group=True,输出层中的模型参数将使用十倍的学习率
    2. def train_fine_tuning(net, learning_rate, batch_size=64, num_epochs=20,
    3. param_group=True):
    4. train_slices = random.sample(list(range(n_train)), 15000)
    5. test_slices = list(set(range(n_train)) - set(train_slices))
    6. train_iter = load_data(train_images[train_slices], train_labels[train_slices], batch_size, train=True)
    7. test_iter = load_data(train_images[test_slices], train_labels[test_slices], batch_size, train=False)
    8. devices = d2l.try_all_gpus()
    9. loss = nn.CrossEntropyLoss(reduction="none")
    10. if param_group:
    11. params_1x = [param for name, param in net.named_parameters()
    12. if name not in ["fc.weight", "fc.bias"]]
    13. #别的层不变,最后一层10倍学习率
    14. trainer = torch.optim.Adam([{'params': params_1x},
    15. {'params': net.fc.parameters(),
    16. 'lr': learning_rate * 10}],
    17. lr=learning_rate, weight_decay=0.001)
    18. else:
    19. trainer = torch.optim.Adam(net.parameters(), lr=learning_rate,
    20. weight_decay=0.001)
    21. print(111)
    22. try:
    23. d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)
    24. except Exception as e:
    25. print(e)
    26. #%%
    27. #较小的学习率,通过微调预训练获得的模型参数
    28. train_fine_tuning(net, 1e-3)

            小破脑跑得慢,之前不用预训练5个epoch后acc大概只能到0.3  ,使用预训练后到了0.6,但实际上感觉对于树叶的针对性分类还是需要从头开始才是最好的选择,资源不够这里就不做尝试了,大概尝试情况:


    CIFAR-10

    1.数据集


    2.未完待续

  • 相关阅读:
    Vue路由的使用
    基于SSM的高校宿舍管理系统
    【ESD专题】案例:双层SAM卡接口是每个脚都需要静电防护吗?
    MySQL进阶——锁
    7、乐趣国学—趣谈“圣贤”
    代码随想录51——动态规划:309最大买卖股票时机含冷冻期、714买卖股票的最大时机含手续费、300最长递增子序列
    WPSpell将拼写检查添加到VCL应用程序
    SWC 流程
    8-事件组或标志
    error=13, Permission denied
  • 原文地址:https://blog.csdn.net/weixin_55706338/article/details/134460465