• pycocotools库的使用


    1.组成模块

    pycocotools下有四个模块:coco、cocoeval、mask和_mask。

    1、coco模块:

    1. # The following API functions are defined:
    2. # COCO - COCO api class that loads COCO annotation file and prepare data structures.
    3. # getAnnIds - Get ann ids that satisfy given filter conditions.
    4. # getCatIds - Get cat ids that satisfy given filter conditions.
    5. # getImgIds - Get img ids that satisfy given filter conditions.
    6. # loadAnns - Load anns with the specified ids.
    7. # loadCats - Load cats with the specified ids.
    8. # loadImgs - Load imgs with the specified ids.
    9. # annToMask - Convert segmentation in an annotation to binary mask.
    10. # showAnns - Display the specified annotations.
    11. # loadRes - Load algorithm results and create API for accessing them.
    12. # download - Download COCO images from mscoco.org server.
    13. # Throughout the API "ann"=annotation, "cat"=category, and "img"=image.
    14. # Help on each functions can be accessed by: "help COCO>function".

    COCO类主要定义了以下11个方法:

    (1)获取标注id:

    1. def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
    2. """
    3. Get ann ids that satisfy given filter conditions. default skips that filter
    4. :param imgIds (int array) : get anns for given imgs
    5. catIds (int array) : get anns for given cats
    6. areaRng (float array) : get anns for given area range (e.g. [0 inf])
    7. iscrowd (boolean) : get anns for given crowd label (False or True)
    8. :return: ids (int array) : integer array of ann ids
    9. """

    (2)获取类别id:

    1. def getCatIds(self, catNms=[], supNms=[], catIds=[]):
    2. """
    3. filtering parameters. default skips that filter.
    4. :param catNms (str array) : get cats for given cat names
    5. :param supNms (str array) : get cats for given supercategory names
    6. :param catIds (int array) : get cats for given cat ids
    7. :return: ids (int array) : integer array of cat ids
    8. """

    (3)获取图片id:

    1. def getImgIds(self, imgIds=[], catIds=[]):
    2. '''
    3. Get img ids that satisfy given filter conditions.
    4. :param imgIds (int array) : get imgs for given ids
    5. :param catIds (int array) : get imgs with all given cats
    6. :return: ids (int array) : integer array of img ids
    7. '''

    (4)加载标注信息:

    1. def loadAnns(self, ids=[]):
    2. """
    3. Load anns with the specified ids.
    4. :param ids (int array) : integer ids specifying anns
    5. :return: anns (object array) : loaded ann objects
    6. """

    (5)加载类别:

    1. def loadCats(self, ids=[]):
    2. """
    3. Load cats with the specified ids.
    4. :param ids (int array) : integer ids specifying cats
    5. :return: cats (object array) : loaded cat objects
    6. """

    (6)加载图片:

    1. def loadImgs(self, ids=[]):
    2. """
    3. Load imgs with the specified ids.
    4. :param ids (int array) : integer ids specifying img
    5. :return: imgs (object array) : loaded img objects
    6. """

    (7)用matplotlib在图片上显示标注:

    1. def showAnns(self, anns):
    2. """
    3. Display the specified annotations.
    4. :param anns (array of object): annotations to display
    5. :return: None
    6. """

    (8)加载结果文件:

    1. def loadRes(self, resFile):
    2. """
    3. Load result file and return a result api object.
    4. :param resFile (str) : file name of result file
    5. :return: res (obj) : result api object
    6. """

    (9)下载数据集:

    1. def download(self, tarDir = None, imgIds = [] ):
    2. '''
    3. Download COCO images from mscoco.org server.
    4. :param tarDir (str): COCO results directory name
    5. imgIds (list): images to be downloaded
    6. :return:
    7. '''

    (10)ann(polygons, uncompressed RLE)转为rle格式(0表示背景,1表示分割区域):

    1. def annToRLE(self, ann):
    2. """
    3. Convert annotation which can be polygons, uncompressed RLE to RLE.
    4. :return: RLE-encoded mask of the annotation
    5. """

    (11)polygons, uncompressed RLE, or RLE 转mask:

    1. def annToMask(self, ann):
    2. """
    3. Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask.
    4. :return: binary mask (numpy 2D array)
    5. """

     

    2、mask模块下定义了四个函数:

    1. def encode(bimask):
    2. def decode(rleObjs):
    3. def area(rleObjs):
    4. def toBbox(rleObjs):

    3、cocoeval模块定义了COCOeval和Params两个类,其用法如下(注释中的CocoEval即COCOeval类):

    1. # The usage for CocoEval is as follows:
    2. # cocoGt=..., cocoDt=... # load dataset and results
    3. # E = CocoEval(cocoGt,cocoDt); # initialize CocoEval object
    4. # E.params.recThrs = ...; # set parameters as desired
    5. # E.evaluate(); # run per image evaluation
    6. # E.accumulate(); # accumulate per image results
    7. # E.summarize(); # display summary metrics of results

    2.使用pycocotools加载coco数据集

    2.1 实例化COCO类

    从COCO类的源码中可以看到,初始化方法会调用createIndex方法。该方法并不返回值,而是构建以下几个索引字典并保存为实例属性:

    字典anns:以标注id为键、标注信息为值的字典

    字典imgs:以图片id为键、图片信息为值的字典

    字典imgToAnns:以图片id为键、该图片的标注信息列表为值的字典

    字典cats:以类别id为键、类别信息为值的字典

    字典catToImgs:以类别id为键、图片id列表为值的字典(每条标注追加一次,因此同一图片可能重复出现)

    class COCO:
        def __init__(self, annotation_file=None):
            """
            Constructor of Microsoft COCO helper class for reading and visualizing annotations.
            :param annotation_file (str): location of annotation file (JSON whose top-level
                value must be a dict). If None, the object starts with an empty dataset
                and empty indices.
            :return:
            """
            # load dataset
            # Start with empty containers so the object is usable even without a file.
            self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
            self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
            if not annotation_file == None:
                print('loading annotations into memory...')
                tic = time.time()
                with open(annotation_file, 'r') as f:
                    dataset = json.load(f)
                # NOTE: this check is an assert, so it is skipped when Python runs with -O.
                assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
                print('Done (t={:0.2f}s)'.format(time.time()- tic))
                self.dataset = dataset
                # Build the id-based lookup tables from the freshly loaded dataset.
                self.createIndex()

        def createIndex(self):
            """Build id-based lookup tables from ``self.dataset`` and store them on the instance.

            The 'annotations', 'images' and 'categories' sections are each optional;
            a missing section leaves the corresponding index empty.
            """
            # create index
            print('creating index...')
            anns, cats, imgs = {}, {}, {}
            imgToAnns,catToImgs = defaultdict(list),defaultdict(list)
            if 'annotations' in self.dataset:
                for ann in self.dataset['annotations']:
                    # image id -> list of that image's annotation dicts
                    imgToAnns[ann['image_id']].append(ann)
                    # annotation id -> annotation dict
                    anns[ann['id']] = ann
            if 'images' in self.dataset:
                for img in self.dataset['images']:
                    # image id -> image info dict
                    imgs[img['id']] = img
            if 'categories' in self.dataset:
                for cat in self.dataset['categories']:
                    # category id -> category info dict
                    cats[cat['id']] = cat
            if 'annotations' in self.dataset and 'categories' in self.dataset:
                for ann in self.dataset['annotations']:
                    # category id -> image ids; appended once per annotation, so an image
                    # holding several objects of the same category appears several times.
                    catToImgs[ann['category_id']].append(ann['image_id'])
            print('index created!')
            # create class members
            self.anns = anns
            self.imgToAnns = imgToAnns
            self.catToImgs = catToImgs
            self.imgs = imgs
            self.cats = cats

    2.2 读取数据

    下面以TensorFlow(Keras)读取数据为例,PyTorch的做法类似:在Keras中继承keras.utils.Sequence并实现__getitem__和__len__方法,相当于在PyTorch中继承torch.utils.data.Dataset并配合DataLoader使用。

    class COCODetection(Sequence):
        """Keras ``Sequence`` yielding Mask R-CNN style training batches from a COCO dataset.

        Batch ``k`` is built from images ``k * batch_size`` .. ``(k + 1) * batch_size - 1``
        of ``self.ids``; indices wrap around, so the last batch is always full.
        """

        def __init__(self, image_path, coco, num_classes, anchors, batch_size, config,
                     COCO_LABEL_MAP=None, augmentation=None):
            """
            :param image_path: directory holding the files named by each image's ``file_name``.
            :param coco: loaded pycocotools ``COCO`` object.
            :param num_classes: number of classes (stored, but not used by this class).
            :param anchors: anchor array; ``anchors.shape[0]`` sizes the RPN match targets.
            :param batch_size: number of images per batch.
            :param config: config object providing IMAGE_SHAPE, MAX_GT_INSTANCES,
                RPN_TRAIN_ANCHORS_PER_IMAGE and USE_MINI_MASK.
            :param COCO_LABEL_MAP: dict mapping COCO ``category_id`` to the training class id.
            :param augmentation: optional callable
                ``(image, masks, boxes, {'num_crowds': n, 'labels': labels})``
                returning ``(image, masks, boxes, {'num_crowds': n, 'labels': labels})``.
            """
            self.image_path = image_path
            self.coco = coco
            # Only images that have at least one annotation are iterated.
            self.ids = list(self.coco.imgToAnns.keys())
            self.num_classes = num_classes
            self.anchors = anchors
            self.batch_size = batch_size
            self.config = config
            self.augmentation = augmentation
            # ``None`` default instead of ``{}`` so instances never share one dict object.
            self.label_map = {} if COCO_LABEL_MAP is None else COCO_LABEL_MAP
            self.length = len(self.ids)

        def __getitem__(self, index):
            """Build batch number ``index``.

            :return: ``(inputs, outputs)`` where ``inputs`` is
                ``[images, image_meta, rpn_match, rpn_bbox, gt_class_ids, gt_boxes, gt_masks]``
                and ``outputs`` is five zero vectors of length ``batch_size``
                (placeholders; presumably the model computes its losses internally -- confirm).
            """
            for i, global_index in enumerate(range(index * self.batch_size, (index + 1) * self.batch_size)):
                # Wrap around the dataset so every batch has batch_size samples.
                global_index = global_index % self.length
                image, boxes, mask_gt, num_crowds, image_id = self.pull_item(global_index)
                # The last column holds the class id, the others the box coordinates.
                class_ids = boxes[:, -1]
                boxes = boxes[:, :-1]
                image, image_meta, gt_class_ids, gt_boxes, gt_masks = \
                    load_image_gt(image, mask_gt, boxes, class_ids, image_id, self.config, use_mini_mask=self.config.USE_MINI_MASK)

                # Allocate the batch arrays once, sized from the first sample.
                if i == 0:
                    batch_image_meta = np.zeros((self.batch_size,) + image_meta.shape, dtype=image_meta.dtype)
                    batch_rpn_match = np.zeros([self.batch_size, self.anchors.shape[0], 1], dtype=np.int32)
                    batch_rpn_bbox = np.zeros([self.batch_size, self.config.RPN_TRAIN_ANCHORS_PER_IMAGE, 4], dtype=np.float32)
                    batch_images = np.zeros((self.batch_size,) + image.shape, dtype=np.float32)
                    batch_gt_class_ids = np.zeros((self.batch_size, self.config.MAX_GT_INSTANCES), dtype=np.int32)
                    batch_gt_boxes = np.zeros((self.batch_size, self.config.MAX_GT_INSTANCES, 4), dtype=np.int32)
                    batch_gt_masks = np.zeros((self.batch_size, gt_masks.shape[0], gt_masks.shape[1], self.config.MAX_GT_INSTANCES), dtype=gt_masks.dtype)

                # Samples without any positive class id are skipped; their slot stays zero-filled.
                if not np.any(gt_class_ids > 0):
                    continue

                # RPN targets
                rpn_match, rpn_bbox = build_rpn_targets(image.shape, self.anchors, gt_class_ids, gt_boxes, self.config)

                # Too many objects in this image: keep a random subset of MAX_GT_INSTANCES.
                if gt_boxes.shape[0] > self.config.MAX_GT_INSTANCES:
                    ids = np.random.choice(
                        np.arange(gt_boxes.shape[0]), self.config.MAX_GT_INSTANCES, replace=False)
                    gt_class_ids = gt_class_ids[ids]
                    gt_boxes = gt_boxes[ids]
                    gt_masks = gt_masks[:, :, ids]

                # Copy this sample into its batch slot.
                batch_image_meta[i] = image_meta
                batch_rpn_match[i] = rpn_match[:, np.newaxis]
                batch_rpn_bbox[i] = rpn_bbox
                batch_images[i] = preprocess_input(image.astype(np.float32))
                batch_gt_class_ids[i, :gt_class_ids.shape[0]] = gt_class_ids
                batch_gt_boxes[i, :gt_boxes.shape[0]] = gt_boxes
                batch_gt_masks[i, :, :, :gt_masks.shape[-1]] = gt_masks

            inputs = [batch_images, batch_image_meta, batch_rpn_match, batch_rpn_bbox,
                      batch_gt_class_ids, batch_gt_boxes, batch_gt_masks]
            return inputs, [np.zeros(self.batch_size) for _ in range(5)]

        def __len__(self):
            """Number of batches per epoch (the last partial batch is padded by wrapping)."""
            return math.ceil(len(self.ids) / float(self.batch_size))

        def pull_item(self, index):
            """Load one image with its boxes and masks.

            :param index: position in ``self.ids``.
            :return: ``(image, boxes, masks, num_crowds, image_id)``:
                ``boxes`` is an int array with rows ``[y1, x1, y2, x2, class_id]`` scaled by
                ``config.IMAGE_SHAPE``; crowd instances come last and carry class id ``-1``.
                ``masks`` has shape ``(height, width, num_instances)``.
            """
            # Look up the COCO image id, then load its annotations.
            image_id = self.ids[index]
            anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=image_id))

            # Split crowd / non-crowd from the ORIGINAL list. Filtering ``target`` first and
            # then searching it for crowds always found none, so num_crowds was always 0 and
            # crowd regions were silently dropped.
            target = [x for x in anns if not ('iscrowd' in x and x['iscrowd'])]
            crowd = [x for x in anns if ('iscrowd' in x and x['iscrowd'])]
            num_crowds = len(crowd)
            # Non-crowd objects first, crowd objects stacked at the end so that they can be
            # addressed as labels[-num_crowds:].
            target += crowd

            image_path = osp.join(self.image_path, self.coco.loadImgs(image_id)[0]['file_name'])
            image = Image.open(image_path)
            image = cvtColor(image)
            image = np.array(image, np.float32)
            height, width, _ = image.shape

            if target:
                masks = np.array([self.coco.annToMask(obj).reshape(-1) for obj in target], np.float32)
                masks = masks.reshape((-1, height, width))
                boxes_classes = []
                for obj in target:
                    # COCO bbox is [x, y, w, h]; convert to [x1, y1, x2, y2, class_id].
                    bbox = obj['bbox']
                    boxes_classes.append([bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3],
                                          self.label_map[obj['category_id']]])
                boxes_classes = np.array(boxes_classes, np.float32)
            else:
                # No annotations: empty arrays keep the code below well defined.
                masks = np.zeros((0, height, width), np.float32)
                boxes_classes = np.zeros((0, 5), np.float32)

            # Normalise x coordinates by the width and y coordinates by the height.
            boxes_classes[:, [0, 2]] /= width
            boxes_classes[:, [1, 3]] /= height

            # Defined on every path (previously ``boxes`` only existed when augmenting).
            boxes = boxes_classes[:, :4]
            labels = boxes_classes[:, 4]
            if self.augmentation is not None and len(boxes_classes) > 0:
                image, masks, boxes, extra = self.augmentation(
                    image, masks, boxes, {'num_crowds': num_crowds, 'labels': labels})
                num_crowds = extra['num_crowds']
                labels = extra['labels']
            if num_crowds > 0:
                # Crowd instances (stacked last) are marked with a negative class id.
                labels[-num_crowds:] = -1
            boxes = np.concatenate([boxes, np.expand_dims(labels, axis=1)], -1)

            # (N, H, W) -> (H, W, N)
            masks = np.transpose(masks, [1, 2, 0])
            # Normalised [x1, y1, x2, y2] -> [y1, x1, y2, x2] scaled by IMAGE_SHAPE.
            # NOTE(review): without augmentation the image itself is not resized here;
            # confirm load_image_gt accounts for that.
            outboxes = np.zeros_like(boxes)
            outboxes[:, [0, 2]] = boxes[:, [1, 3]] * self.config.IMAGE_SHAPE[0]
            outboxes[:, [1, 3]] = boxes[:, [0, 2]] * self.config.IMAGE_SHAPE[1]
            outboxes[:, -1] = boxes[:, -1]
            # ``np.int`` (an alias of the builtin int) was removed in NumPy 1.24.
            outboxes = np.array(outboxes, int)
            return image, outboxes, masks, num_crowds, image_id

     

  • 相关阅读:
    再理解springboot那些注册与回调、监控与统计等命名规范,就可以读懂70%的springboot源代码
    让历史文化“活”起来,北京河图“万象中轴”助力打造北京城市金名片
    Java_自定义实体类的列表List<T>调用remove()失败讲解
    Denpendcy Injection 8.0新功能——KeyedService
    【微信小程序】自定义组件(一)
    基于Rspack实现大仓应用构建提效实践|得物技术
    SQL: MAX Function
    【AI工程】08-MLOps工具-在Charmed Kubeflow上运行MindSpore
    免费调用快递鸟物流跟踪轨迹订阅接口技术文档
    ubuntu22.04.1 新装后的常用设置
  • 原文地址:https://blog.csdn.net/qq_52053775/article/details/126197172