• mmclassification 训练自定义数据


    1 mmclassification 安装

            如果环境已安装mmclassification,请跳过该步骤。mmclassification框架安装与调试验证请参考博客:mmclassification安装与调试_Coding的叶子的博客-CSDN博客_mmclassification 安装

    2 数据集准备

            mmclassification 的数据集目录主要由标注文件和图片样本组成,其中标注文件存储在meta文件夹中,图片样本存在train、val、test文件夹下,即分别是用于训练、验证和测试的图片样本。图片样本文件按照类别存储在train、val、test文件夹下,同一类别图片存储在同一个子文件夹中,子文件夹的名称为图片所属类别名称。

            meta文件夹中主要包含了train.txt、val.txt和test.txt文件。txt文件中的每一行分别存储了图片样本路径和类别id,如下图所示。

            如果没有meta标注文件,请参考博客:mmclassification 标注文件生成_Coding的叶子的博客-CSDN博客,生成meta文件夹及其文件夹下的txt文件。

             本文示例数据来源于minist手写字体可视化数据集,已按照train、test文件夹进行存储,下载地址为:minist手写数字可视化数据集-深度学习文档类资源-CSDN下载

            将下载的数据集文件夹名称重名为Minist,并且mmclassification工程目录下新建data文件夹,将数据集放到data文件夹下即可。数据集的存储路径不限,需要在下方3.3节中配置相应的路径即可。

    3 自定义数据集

    3.1 新建MyDataset

            在mmclassification工程目录下的mmcls/datasets/新建mydataset.py文件,自定义数据加载类MyDataset,文件名称mydataset和类名称MyDataset可以自行更改。mydataset.py文件中的内容如下: 

    1. # -*- coding: utf-8 -*-
    2. """
    3. 乐乐感知学堂公众号
    4. @author: https://blog.csdn.net/suiyingy
    5. """
    6. import numpy as np
    7. from .builder import DATASETS
    8. from .base_dataset import BaseDataset
    9. @DATASETS.register_module()
    10. class MyDataset(BaseDataset):
    11. def load_annotations(self):
    12. assert isinstance(self.ann_file, str)
    13. data_infos = []
    14. with open(self.ann_file) as f:
    15. samples = [x.strip().split(' ') for x in f.readlines()]
    16. for filename, gt_label in samples:
    17. info = {'img_prefix': self.data_prefix}
    18. info['img_info'] = {'filename': filename}
    19. info['gt_label'] = np.array(gt_label, dtype=np.int64)
    20. data_infos.append(info)
    21. return data_infos

     3.2 将MyDataset注册到mmclassification框架

            在mmcls/datasets/__init__.py文件中增加上面定义的类MyDataset,如下图所示:

     3.3 新建数据集配置文件

            在mmclassification工程目录configs/_base_/datasets/文件夹下,新建mydataset.py文件,主要用于设置数据集类型、数据增强方式、batch size (samples_per_gpu)、数据集路径和标注文件路径、模型保存周期(interval)。文件内容如下所示:

    1. # -*- coding: utf-8 -*-
    2. """
    3. 乐乐感知学堂公众号
    4. @author: https://blog.csdn.net/suiyingy
    5. """
    6. dataset_type = 'MyDataset'
    7. classes = ['cat', 'bird', 'dog'] # The category names of your dataset
    8. img_norm_cfg = dict(
    9. mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
    10. train_pipeline = [
    11. dict(type='LoadImageFromFile'),
    12. dict(type='RandomResizedCrop', size=224),
    13. dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    14. dict(type='Normalize', **img_norm_cfg),
    15. dict(type='ImageToTensor', keys=['img']),
    16. dict(type='ToTensor', keys=['gt_label']),
    17. dict(type='Collect', keys=['img', 'gt_label'])
    18. ]
    19. test_pipeline = [
    20. dict(type='LoadImageFromFile'),
    21. dict(type='Resize', size=(256, -1)),
    22. dict(type='CenterCrop', crop_size=224),
    23. dict(type='Normalize', **img_norm_cfg),
    24. dict(type='ImageToTensor', keys=['img']),
    25. dict(type='Collect', keys=['img'])
    26. ]
    27. data = dict(
    28. train=dict(
    29. type=dataset_type,
    30. data_prefix='data/Minist/train',
    31. ann_file='data/Minist/meta/train.txt',
    32. classes=classes,
    33. pipeline=train_pipeline
    34. ),
    35. val=dict(
    36. type=dataset_type,
    37. data_prefix='data/Minist/test',
    38. ann_file='data/Minist/meta/test.txt',
    39. classes=classes,
    40. pipeline=test_pipeline
    41. ),
    42. test=dict(
    43. type=dataset_type,
    44. data_prefix='data/Minist/test',
    45. ann_file='data/Minist/meta/test.txt',
    46. classes=classes,
    47. pipeline=test_pipeline
    48. )
    49. )
    50. evaluation = dict(interval=1, metric='accuracy')

    4 修改configs模型配置文件

            以configs/resnet/resnet18_8xb16_cifar10.py配置文件为例,mmclassification的配置文件通常包含以下4个部分:

    1. _base_ = [
    2.     '../_base_/models/resnet18_cifar.py', '../_base_/datasets/cifar10_bs16.py',
    3.     '../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
    4. ]

            ../_base_/models/resnet18_cifar.py:定义模型参数,主要包括主干网络、neck、head和类别数量。

            ../_base_/datasets/cifar10_bs16.py:定义数据集增强方式和路径,也就是3.3节的配置文件,bs16表示batch size为16,即samples_per_gpu=16。

            ../_base_/schedules/cifar10_bs128.py:定义训练参数,主要包括优化器、学习率、训练总epoch数量。

            ../_base_/default_runtime.py:定义运行参数,主要包括模型保存周期、日志输出周期等。

            configs主要修改的地方为数据配置文件,即把 '../_base_/datasets/cifar10_bs16.py'更换成3.3节中的配置文件'../_base_/datasets/mydataset.py'。即:

    5 运行训练程序

            mmcls基本的训练命令为:

    python tools/train.py 模型配置文件

            示例:

    python tools/train.py configs/resnet/resnet18_8xb16_cifar10.py

            这里已经把resnet18_8xb16_cifar10.py文件按照第4节进行了修改。

    6 运行结果

     【python三维深度学习】python三维点云从基础到深度学习_Coding的叶子的博客-CSDN博客_python 三维点云

    更多三维、二维感知算法和金融量化分析算法请关注“乐乐感知学堂”微信公众号,并将持续进行更新。

  • 相关阅读:
    JavaScript中的短路表达式
    2024HW --->蓝队面试题
    计算机系统(22)----- 管程、死锁
    Python150题day10
    shell脚本按日期范围和间隔下载数据
    ARM GNU汇编代码分析
    使用WebApi+Vue3从0到1搭建《权限管理系统》:二、搭建JWT系统鉴权
    小约翰可汗视频随记
    【AUTOSAR-RTE】-3-Runnable及其Task Mapping映射
    SpringCloud-Bus
  • 原文地址:https://blog.csdn.net/suiyingy/article/details/125551909