如果环境已安装mmclassification,请跳过该步骤。mmclassification框架安装与调试验证请参考博客:mmclassification安装与调试_Coding的叶子的博客-CSDN博客_mmclassification 安装。
mmclassification 的数据集目录主要由标注文件和图片样本组成,其中标注文件存储在meta文件夹中,图片样本存在train、val、test文件夹下,即分别是用于训练、验证和测试的图片样本。图片样本文件按照类别存储在train、val、test文件夹下,同一类别图片存储在同一个子文件夹中,子文件夹的名称为图片所属类别名称。
meta文件夹中主要包含了train.txt、val.txt和test.txt文件。txt文件中的每一行分别存储了图片样本路径和类别id,如下图所示。
如果没有meta标注文件,请参考博客:mmclassification 标注文件生成_Coding的叶子的博客-CSDN博客,生成meta文件夹及其文件夹下的txt文件。
本文示例数据来源于minist手写字体可视化数据集,已按照train、test文件夹进行存储,下载地址为:minist手写数字可视化数据集-深度学习文档类资源-CSDN下载。
将下载的数据集文件夹名称重名为Minist,并且mmclassification工程目录下新建data文件夹,将数据集放到data文件夹下即可。数据集的存储路径不限,需要在下方3.3节中配置相应的路径即可。
在mmclassification工程目录下的mmcls/datasets/新建mydataset.py文件,自定义数据加载类MyDataset,文件名称mydataset和类名称MyDataset可以自行更改。mydataset.py文件中的内容如下:
- # -*- coding: utf-8 -*-
- """
- 乐乐感知学堂公众号
- @author: https://blog.csdn.net/suiyingy
- """
-
- import numpy as np
-
- from .builder import DATASETS
- from .base_dataset import BaseDataset
-
-
-
- @DATASETS.register_module()
- class MyDataset(BaseDataset):
-
- def load_annotations(self):
- assert isinstance(self.ann_file, str)
-
- data_infos = []
- with open(self.ann_file) as f:
- samples = [x.strip().split(' ') for x in f.readlines()]
- for filename, gt_label in samples:
- info = {'img_prefix': self.data_prefix}
- info['img_info'] = {'filename': filename}
- info['gt_label'] = np.array(gt_label, dtype=np.int64)
- data_infos.append(info)
- return data_infos
在mmcls/datasets/__init__.py文件中增加上面定义的类MyDataset,如下图所示:
在mmclassification工程目录configs/_base_/datasets/文件夹下,新建mydataset.py文件,主要用于设置数据集类型、数据增强方式、batch size (samples_per_gpu)、数据集路径和标注文件路径、模型保存周期(interval)。文件内容如下所示:
- # -*- coding: utf-8 -*-
- """
- 乐乐感知学堂公众号
- @author: https://blog.csdn.net/suiyingy
- """
- dataset_type = 'MyDataset'
- classes = ['cat', 'bird', 'dog'] # The category names of your dataset
-
- img_norm_cfg = dict(
- mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
-
- train_pipeline = [
- dict(type='LoadImageFromFile'),
- dict(type='RandomResizedCrop', size=224),
- dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
- dict(type='Normalize', **img_norm_cfg),
- dict(type='ImageToTensor', keys=['img']),
- dict(type='ToTensor', keys=['gt_label']),
- dict(type='Collect', keys=['img', 'gt_label'])
- ]
- test_pipeline = [
- dict(type='LoadImageFromFile'),
- dict(type='Resize', size=(256, -1)),
- dict(type='CenterCrop', crop_size=224),
- dict(type='Normalize', **img_norm_cfg),
- dict(type='ImageToTensor', keys=['img']),
- dict(type='Collect', keys=['img'])
- ]
-
- data = dict(
- train=dict(
- type=dataset_type,
- data_prefix='data/Minist/train',
- ann_file='data/Minist/meta/train.txt',
- classes=classes,
- pipeline=train_pipeline
- ),
- val=dict(
- type=dataset_type,
- data_prefix='data/Minist/test',
- ann_file='data/Minist/meta/test.txt',
- classes=classes,
- pipeline=test_pipeline
- ),
- test=dict(
- type=dataset_type,
- data_prefix='data/Minist/test',
- ann_file='data/Minist/meta/test.txt',
- classes=classes,
- pipeline=test_pipeline
- )
- )
- evaluation = dict(interval=1, metric='accuracy')
以configs/resnet/resnet18_8xb16_cifar10.py配置文件为例,mmclassification的配置文件通常包含以下4个部分:
- _base_ = [
- '../_base_/models/resnet18_cifar.py', '../_base_/datasets/cifar10_bs16.py',
- '../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
- ]
../_base_/models/resnet18_cifar.py:定义模型参数,主要包括主干网络、neck、head和类别数量。
../_base_/datasets/cifar10_bs16.py:定义数据集增强方式和路径,也就是3.3节的配置文件,bs16表示batch size为16,即samples_per_gpu=16。
../_base_/schedules/cifar10_bs128.py:定义训练参数,主要包括优化器、学习率、训练总epoch数量。
../_base_/default_runtime.py:定义运行参数,主要包括模型保存周期、日志输出周期等。
configs主要修改的地方为数据配置文件,即把 '../_base_/datasets/cifar10_bs16.py'更换成3.3节中的配置文件'../_base_/datasets/mydataset.py'。即:
mmcls基本的训练命令为:
python tools/train.py 模型配置文件
示例:
python tools/train.py configs/resnet/resnet18_8xb16_cifar10.py
这里已经把resnet18_8xb16_cifar10.py文件按照第4节进行了修改。
更多三维、二维感知算法和金融量化分析算法请关注“乐乐感知学堂”微信公众号,并将持续进行更新。