• OD【1】:自定义Dataset



    前言

    本文以 Pascal VOC 2012 数据集为例,讲解如何自定义一个可以用于目标检测的数据集

    参考 Pytorch 官方提供的样例:Tutorial


    1. Introduction

    训练对象检测、实例分割和人物关键点检测的参考脚本可以轻松支持添加新的自定义数据集。数据集应该继承于标准的 torch.utils.data.Dataset 类,并实现 __len____getitem__ 方法

    • __len__:图片的数量
    • __getitem__:图片及其对应的信息
    • get_height_and_width:获取图像高度和宽度的方法
      • 在多 GPU 训练的时候用到

    2. Dataset

    定义一个继承 Dataset 的类

    from torch.utils.data import Dataset
    
    class VOC2012DataSet(Dataset):		
    
    • 1
    • 2
    • 3

    2.1. __ init __

    	def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
    
    • 1
    • voc_root:训练集所在的根目录
    • transforms:图像预处理
    • txt_name:在后面判断是调用训练集数据还是验证集数据
    		assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
    		# 增加容错能力
    		if "VOCdevkit" in voc_root:
    			self.root = os.path.join(voc_root, f"VOC{year}")
        	else:
            	self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
        	self.img_root = os.path.join(self.root, "JPEGImages")
        	self.annotations_root = os.path.join(self.root, "Annotations")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    通过 os.path.join 方法将各个目录与根目录拼接到一起(os.path 可用理解为是路径中的斜杠,可以根据不同的操作系统适应不同方向的斜杠)

    		# read train.txt or val.txt file
    		txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
    		assert os.path.exists(txt_path), "not found {} file.".format(txt_name)
        	with open(txt_path) as read:
    			xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
    						for line in read.readlines() if len(line.strip()) > 0]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    根据传入的信息判断对应读取的是训练集还是验证集,并通过 os.path.join 构建对应的文件路径

    其中train.txtval.txt 中包含的对应的文件的名称,且每一行最后都有一个换行符。通过 for 循环遍历所有文件名,使用 line.strip() 方法,去掉最后的换行符,并给每一个文件名后面添加 .xml 的后缀,得到每一个文件对应的 xml 文件,并将所有 xml 文件保存到一个列表中

    		# read class_indict
    		json_file = '/od/model/faster_rcnn/pascal_voc_classes.json'
    		assert os.path.exists(json_file), "{} file not exist.".format(json_file)
    		with open(json_file, 'r') as f:
    			self.class_dict = json.load(f)
    
    		self.transforms = transforms
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    载入之前设定好的类别名称及其对应的索引信息,并存给 self.class_dict 找个变量

    json 文件内容如下所示:

    {
        "aeroplane": 1,
        "bicycle": 2,
        "bird": 3,
        "boat": 4,
        "bottle": 5,
        "bus": 6,
        "car": 7,
        "cat": 8,
        "chair": 9,
        "cow": 10,
        "diningtable": 11,
        "dog": 12,
        "horse": 13,
        "motorbike": 14,
        "person": 15,
        "pottedplant": 16,
        "sheep": 17,
        "sofa": 18,
        "train": 19,
        "tvmonitor": 20
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    2.2. __ len __

        def __len__(self):
            return len(self.xml_list)
    
    • 1
    • 2

    通过 len() 方法获取文件的个数,变量 xml_list 中存储着所有的 .xml 文件,而一个 .xml 文件就对应着一张图片,所以 len() 方法返回的就是数据集的个数

    2.3. __ getitem __

    	def __getitem__(self, idx):
    
    • 1
    • idx:索引值,通过索引值返回图片及图片对应的信息
            # read xml
            xml_path = self.xml_list[idx]
            with open(xml_path) as fid:
                xml_str = fid.read()
            xml = etree.fromstring(xml_str)
            data = self.parse_xml_to_dict(xml)["annotation"]
            img_path = os.path.join(self.img_root, data["filename"])
            image = Image.open(img_path)
            if image.format != "JPEG":
                raise ValueError("Image '{}' format not JPEG".format(img_path))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    找到 idx 对应的 xml 文件并打开,通过 etree.fromstring() 读取 xml 文件并将 xml 文件中的信息传入方法 parse_xml_to_dict(将 xml 文件解析成字典形式进行存储)

    在使用之前需要判断图片文件的类型是否是 .jpeg 格式(如果使用的是 VOC 数据集的话将不会有什么影响)

    		boxes = []
            labels = []
            iscrowd = []
            assert "object" in data, "{} lack of object information.".format(xml_path)
            for obj in data["object"]:
                xmin = float(obj["bndbox"]["xmin"])
                xmax = float(obj["bndbox"]["xmax"])
                ymin = float(obj["bndbox"]["ymin"])
                ymax = float(obj["bndbox"]["ymax"])
    
                # 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
                if xmax <= xmin or ymax <= ymin:
                    print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                    continue
                
                boxes.append([xmin, ymin, xmax, ymax])
                labels.append(self.class_dict[obj["name"]])
                if "difficult" in obj:
                    iscrowd.append(int(obj["difficult"]))
                else:
                    iscrowd.append(0)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • boxes:保存每一个目标的 bounding box 的信息
      • 将 xmin、xmax、ymin、ymax 转换为 float 类型并添加到 boxes 列表中
    • labels:存储对应目标的索引值(.json 文件中设定的)
      • 通过读入的键值对,传入当前目标的 name 来获取对应的索引值并添加到 labels 列表中
    • iscrowd:存在于 COCO 数据集中(判断目标是否重合)
      • 如果 iscrowd = 0,一般代表是一个单目标,容易检测
      • 对应 difficult 参数
    		# convert everything into a torch.Tensor
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
            iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
            image_id = torch.tensor([idx])
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
    
            target = {}
            target["boxes"] = boxes
            target["labels"] = labels
            target["image_id"] = image_id
            target["area"] = area
            target["iscrowd"] = iscrowd
    
            if self.transforms is not None:
                image, target = self.transforms(image, target)
    
            return image, target
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    2.4. parse_xml_to_dict

    将 xml 文件解析成字典形式,参考 tensorflow 的 recursive_parse_xml_to_dict

    简单来说,xml 就是嵌套的字典

    	def parse_xml_to_dict(self, xml):
    		if len(xml) == 0: # 遍历到底层,直接返回 tag 对应的信息
    			return {xml.tag: xml.text}
    		result = {]
    		for child in xml:
    			child_result = self.parse_xml_to_dict(child) # 递归遍历标签信息
    			if child.tag != 'object':
    				result[child.tag] = child_result[child.tag]
    			else:
    				if child.tag not in result: # 因为 object 可能有多个,所以需要放入列表里
    					result[child.tag] = []
    				result[child.tag].append(child_result[child.tag])
    		return {xml.tag: result}			
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    在这里插入图片描述

    深度优先遍历

    • 先判断 xml 文件是否在 xml 嵌套关系中的底层
      • 如一开始处于 < annotation > 标签,即 xml 文件的顶层
      • 此时看是否有 xml 文件的子目录
    • 初始化一个空字典
    • 遍历上层标签的所有目录
      • 递归调用该方法,解析下一层的 xml 文件
      • 将得到的结果存入字典
      • 判断当前的子目录的 tag 是否为 object
        • 不为 object 则在字典中存入相应的变量
          • e.g. folder - VOC2012
        • 为 object,则以 object 为 key 创建一个空的 list,将解析得到的 object 的 信息添加到列表中
          在这里插入图片描述
          • 存在很多的 object,不能像其他参数一样通过赋值实现,所以需要一个 list 来存储对应的信息

    2.5. get_height_and_width

    获取图像的高度和宽度

        def get_height_and_width(self, idx):
            # read xml
            xml_path = self.xml_list[idx]
            with open(xml_path) as fid:
                xml_str = fid.read()
            xml = etree.fromstring(xml_str)
            data = self.parse_xml_to_dict(xml)["annotation"]
            data_height = int(data["size"]["height"])
            data_width = int(data["size"]["width"])
            return data_height, data_width
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    总结

    通过以上文件代码,我们可以在训练文件中,可以通过 torch.utils.data.DataLoader 传入数据并进行训练

  • 相关阅读:
    Kubernetes 平面组件 etcd
    vite-plugin-vue-setup-extend和unplugin-auto-import让v3开发更丝滑
    力扣第46天--- 第583题、第72题
    Python中的模块
    2022-11-27阿里云物联网平台 MICROPYTHON记录
    Java多线程之死锁
    QT编程,QT内存管理、信号与槽、
    知识竞赛活动舞台搭建需要多少钱
    为什么价格监测要精确到款式
    【三维目标检测】PointRCNN(一)
  • 原文地址:https://blog.csdn.net/HoraceYan/article/details/126544491