Deep Learning Fundamentals: A Walkthrough of the Data-Preprocessing transforms Pipeline


    Data augmentation enlarges the effective training set, alleviates overfitting, and improves the model's ability to generalize, which in turn boosts overall performance.

    1. Data preprocessing


    2. Where the transforms are applied

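    Transforms are normally applied inside a Dataset's __getitem__, so every sample is augmented on the fly each time it is fetched. A minimal sketch (ToyDataset is a hypothetical example, not code from the original post):

    from torch.utils.data import Dataset


    class ToyDataset(Dataset):
        def __init__(self, images, labels, transform=None):
            self.images = images        # e.g. a list of PIL images
            self.labels = labels
            self.transform = transform  # e.g. a transforms.Compose pipeline

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            img, label = self.images[idx], self.labels[idx]
            if self.transform is not None:
                img = self.transform(img)  # the augmentation happens here, per sample
            return img, label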

    3. transforms.RandomResizedCrop: random-scale crop and resize

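    RandomResizedCrop first samples a crop with random area (controlled by scale) and random aspect ratio (controlled by ratio), then resizes that crop to the target size. A minimal usage sketch with torchvision's default parameters (demo.jpg is a placeholder file name):

    from PIL import Image
    import torchvision.transforms as transforms

    # sample a crop covering 8%-100% of the image area with aspect ratio
    # in [3/4, 4/3] (the torchvision defaults), then resize it to 224x224
    rrc = transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3/4, 4/3))

    img = Image.open("demo.jpg")
    out = rrc(img)   # a 224x224 PIL image, with a different crop on every call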

    4. Horizontal flip and vertical flip

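    Each flip is applied independently with a given probability (0.5 by default), so over a long training run roughly half of the samples are mirrored. A short sketch:

    import torchvision.transforms as transforms

    hflip = transforms.RandomHorizontalFlip(p=0.5)  # mirror left-right with probability 0.5
    vflip = transforms.RandomVerticalFlip(p=0.5)    # mirror top-bottom with probability 0.5

    # out = vflip(hflip(img))  # the two flips are sampled independently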

    5. The ColorJitter transform

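    ColorJitter randomly perturbs brightness, contrast, saturation, and hue; each factor is re-sampled on every call. The parameter values below are illustrative, not prescribed:

    import torchvision.transforms as transforms

    # brightness/contrast/saturation factors are sampled from [1-0.4, 1+0.4],
    # hue offsets from [-0.1, 0.1]
    jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                    saturation=0.4, hue=0.1)
    # out = jitter(img)  # img: a PIL image (or a tensor in newer torchvision)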

    6. ToTensor

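    ToTensor converts a PIL image (or a uint8 numpy array) in HWC layout with values 0-255 into a float32 tensor in CHW layout with values scaled to [0, 1]. A small self-contained check:

    import numpy as np
    from PIL import Image
    import torchvision.transforms as transforms

    img = Image.fromarray(np.random.randint(0, 256, (32, 48, 3), dtype=np.uint8))
    tensor = transforms.ToTensor()(img)

    print(tensor.shape)                          # torch.Size([3, 32, 48]): HWC -> CHW
    print(tensor.min() >= 0, tensor.max() <= 1)  # values rescaled to [0, 1]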

    7. Normalization

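    Normalize standardizes each channel as output[c] = (input[c] - mean[c]) / std[c]; it expects a tensor, so it must come after ToTensor. The mean/std below are the commonly used ImageNet statistics:

    import torch
    import torchvision.transforms as transforms

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    x = torch.rand(3, 224, 224)  # a ToTensor-style image with values in [0, 1]
    y = normalize(x)             # roughly zero-mean, unit-variance per channel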

    8. transforms.Compose

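    Compose chains the individual transforms and runs them top to bottom; PIL-based transforms (crops, flips, color jitter) go before ToTensor, while Normalize goes after it. A typical training pipeline assembled from the pieces above:

    import torchvision.transforms as transforms

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        transforms.ToTensor(),                    # PIL -> CHW float tensor in [0, 1]
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # tensor = train_transform(pil_image)  # transforms run in the listed order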

    9. Rewriting transforms

    1. Classification tasks

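    For classification only the image needs transforming (the label is just an integer), so torchvision's built-in transforms can be used directly without any rewriting. A typical train/val setup in that spirit (paths and sizes are illustrative):

    import torchvision.transforms as transforms
    from torchvision.datasets import ImageFolder

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406],
                                                          [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406],
                                                        [0.229, 0.224, 0.225])])}

    # train_set = ImageFolder("data/train", transform=data_transform["train"])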

    2. Object detection tasks

    The purpose of rewriting the transforms is to let each transform accept multiple arguments and process the image and its annotations in sync.
    The rewritten transforms from an SSD object detection project serve as the example below.
    The rewritten transforms.py file:

    import random
    
    import torch
    import torchvision.transforms as t
    from torchvision.transforms import functional as F
    
    from src import dboxes300_coco, calc_iou_tensor, Encoder
    
    
    class Compose(object):
        """组合多个transform函数"""
        def __init__(self, transforms):
            self.transforms = transforms
    
        def __call__(self, image, target=None):
            for trans in self.transforms:
                image, target = trans(image, target)
            return image, target
    
    
    class ToTensor(object):
        """将PIL图像转为Tensor"""
        def __call__(self, image, target):
            image = F.to_tensor(image).contiguous()
            return image, target
    
    
    class RandomHorizontalFlip(object):
        """随机水平翻转图像以及bboxes,该方法应放在ToTensor后"""
        def __init__(self, prob=0.5):
            self.prob = prob
    
        def __call__(self, image, target):
            if random.random() < self.prob:
                # height, width = image.shape[-2:]
                image = image.flip(-1)  # flip the image horizontally
                bbox = target["boxes"]
                # bbox: xmin, ymin, xmax, ymax
                # bbox[:, [0, 2]] = width - bbox[:, [2, 0]]  # flip variant for absolute pixel coordinates
                bbox[:, [0, 2]] = 1.0 - bbox[:, [2, 0]]  # boxes here are relative (0-1), so mirror around 1.0
                target["boxes"] = bbox
            return image, target
    
    
    # This function is from https://github.com/chauhan-utk/ssd.DomainAdaptation.
    class SSDCropping(object):
        """
        Crop the image as described in the original paper; place this before ToTensor.
        Cropping for SSD, according to original paper
        Choose between following 3 conditions:
        1. Preserve the original image
        2. Random crop minimum IoU is among 0.1, 0.3, 0.5, 0.7, 0.9
        3. Random crop
        Reference to https://github.com/chauhan-utk/src.DomainAdaptation
        """
        def __init__(self):
            self.sample_options = (
                # Do nothing
                None,
                # min IoU, max IoU
                (0.1, None),
                (0.3, None),
                (0.5, None),
                (0.7, None),
                (0.9, None),
                # no IoU requirements
                (None, None),
            )
            self.dboxes = dboxes300_coco()
    
        def __call__(self, image, target):
            # Ensure always return cropped image
            while True:
                mode = random.choice(self.sample_options)
                if mode is None:  # skip random cropping for this sample
                    return image, target
    
                htot, wtot = target['height_width']
    
                min_iou, max_iou = mode
                min_iou = float('-inf') if min_iou is None else min_iou
                max_iou = float('+inf') if max_iou is None else max_iou
    
                # the implementation uses 5 iterations to find a possible candidate
                for _ in range(5):
                    # 0.3*0.3 approx. 0.1
                    w = random.uniform(0.3, 1.0)
                    h = random.uniform(0.3, 1.0)
    
                    if w/h < 0.5 or w/h > 2:  # keep the aspect ratio between 0.5 and 2
                        continue
    
                    # left 0 ~ wtot - w, top 0 ~ htot - h
                    left = random.uniform(0, 1.0 - w)
                    top = random.uniform(0, 1.0 - h)
    
                    right = left + w
                    bottom = top + h
    
                    # box coordinates are relative values in [0, 1]
                    bboxes = target["boxes"]
                    ious = calc_iou_tensor(bboxes, torch.tensor([[left, top, right, bottom]]))
    
                    # tailor all the bboxes and return
                    # all(): Returns True if all elements in the tensor are True, False otherwise.
                    if not ((ious > min_iou) & (ious < max_iou)).all():
                        continue
    
                    # discard any bboxes whose center not in the cropped image
                    xc = 0.5 * (bboxes[:, 0] + bboxes[:, 2])
                    yc = 0.5 * (bboxes[:, 1] + bboxes[:, 3])
    
                    # check whether any gt box center falls inside the sampled patch
                    masks = (xc > left) & (xc < right) & (yc > top) & (yc < bottom)
    
                    # if none of the gt box centers lie inside the sampled patch, sample again
                    if not masks.any():
                        continue
    
                    # clip every gt box to the sampled patch (prevent out-of-bounds coordinates)
                    bboxes[bboxes[:, 0] < left, 0] = left
                    bboxes[bboxes[:, 1] < top, 1] = top
                    bboxes[bboxes[:, 2] > right, 2] = right
                    bboxes[bboxes[:, 3] > bottom, 3] = bottom
    
                    # filter out gt boxes that are not inside the sampled patch
                    bboxes = bboxes[masks, :]
                    # keep the labels of the gt boxes that remain in the patch
                    labels = target['labels']
                    labels = labels[masks]
    
                    # crop the patch from the image
                    left_idx = int(left * wtot)
                    top_idx = int(top * htot)
                    right_idx = int(right * wtot)
                    bottom_idx = int(bottom * htot)
                    image = image.crop((left_idx, top_idx, right_idx, bottom_idx))
    
                    # adjust the bbox coordinates relative to the cropped patch
                    bboxes[:, 0] = (bboxes[:, 0] - left) / w
                    bboxes[:, 1] = (bboxes[:, 1] - top) / h
                    bboxes[:, 2] = (bboxes[:, 2] - left) / w
                    bboxes[:, 3] = (bboxes[:, 3] - top) / h
    
                    # store the updated gt box coordinates and labels
                    target['boxes'] = bboxes
                    target['labels'] = labels
    
                    return image, target
    
    
    class Resize(object):
        """对图像进行resize处理,该方法应放在ToTensor前"""
        def __init__(self, size=(300, 300)):
            self.resize = t.Resize(size)
    
        def __call__(self, image, target):
            image = self.resize(image)
            return image, target
    
    
    class ColorJitter(object):
        """对图像颜色信息进行随机调整,该方法应放在ToTensor前"""
        def __init__(self, brightness=0.125, contrast=0.5, saturation=0.5, hue=0.05):
            self.trans = t.ColorJitter(brightness, contrast, saturation, hue)
    
        def __call__(self, image, target):
            image = self.trans(image)
            return image, target
    
    
    class Normalization(object):
        """对图像标准化处理,该方法应放在ToTensor后"""
        def __init__(self, mean=None, std=None):
            if mean is None:
                mean = [0.485, 0.456, 0.406]
            if std is None:
                std = [0.229, 0.224, 0.225]
            self.normalize = t.Normalize(mean=mean, std=std)
    
        def __call__(self, image, target):
            image = self.normalize(image)
            return image, target
    
    
    class AssignGTtoDefaultBox(object):
        """将DefaultBox与GT进行匹配"""
        def __init__(self):
            self.default_box = dboxes300_coco()
            self.encoder = Encoder(self.default_box)
    
        def __call__(self, image, target):
            boxes = target['boxes']
            labels = target["labels"]
            # bboxes_out (Tensor 8732 x 4), labels_out (Tensor 8732)
            bboxes_out, labels_out = self.encoder.encode(boxes, labels)
            target['boxes'] = bboxes_out
            target['labels'] = labels_out
    
            return image, target
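
    Putting the pieces together, a plausible composition for training and validation (a sketch in the spirit of the SSD project, not necessarily the original train script; it respects the docstring constraints above: SSDCropping/Resize/ColorJitter before ToTensor, RandomHorizontalFlip/Normalization/AssignGTtoDefaultBox after it):

    data_transform = {
        "train": Compose([SSDCropping(),
                          Resize(),
                          ColorJitter(),
                          ToTensor(),
                          RandomHorizontalFlip(),
                          Normalization(),
                          AssignGTtoDefaultBox()]),
        "val": Compose([Resize(),
                        ToTensor(),
                        Normalization()])
    }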
    

    The rewritten dataset class file:

    from torch.utils.data import Dataset
    import os
    import torch
    import json
    from PIL import Image
    from lxml import etree
    
    
    class VOCDataSet(Dataset):
        """读取解析PASCAL VOC2007/2012数据集"""
    
        def __init__(self, voc_root, year="2012", transforms=None, train_set='train.txt'):
            assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
            # tolerate roots given with or without the VOCdevkit folder
            if "VOCdevkit" in voc_root:
                self.root = os.path.join(voc_root, f"VOC{year}")
            else:
                self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
            self.img_root = os.path.join(self.root, "JPEGImages")
            self.annotations_root = os.path.join(self.root, "Annotations")
    
            txt_list = os.path.join(self.root, "ImageSets", "Main", train_set)
    
            with open(txt_list) as read:
                self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                                 for line in read.readlines() if len(line.strip()) > 0]
    
            # read class_indict
            json_file = "./pascal_voc_classes.json"
            assert os.path.exists(json_file), "{} file not exist.".format(json_file)
            with open(json_file, 'r') as f:
                self.class_dict = json.load(f)
    
            self.transforms = transforms
    
        def __len__(self):
            return len(self.xml_list)
    
        def __getitem__(self, idx):
            # read xml
            xml_path = self.xml_list[idx]
            with open(xml_path) as fid:
                xml_str = fid.read()
            xml = etree.fromstring(xml_str)
            data = self.parse_xml_to_dict(xml)["annotation"]
            data_height = int(data["size"]["height"])
            data_width = int(data["size"]["width"])
            height_width = [data_height, data_width]
            img_path = os.path.join(self.img_root, data["filename"])
            image = Image.open(img_path)
            if image.format != "JPEG":
                raise ValueError("Image '{}' format not JPEG".format(img_path))
    
            assert "object" in data, "{} lack of object information.".format(xml_path)
            boxes = []
            labels = []
            iscrowd = []
            for obj in data["object"]:
                # convert all gt box coordinates to relative values in [0, 1]
                xmin = float(obj["bndbox"]["xmin"]) / data_width
                xmax = float(obj["bndbox"]["xmax"]) / data_width
                ymin = float(obj["bndbox"]["ymin"]) / data_height
                ymax = float(obj["bndbox"]["ymax"]) / data_height
    
                # extra sanity check: some annotations contain boxes whose w or h is 0,
                # and such samples make the regression loss become NaN
                if xmax <= xmin or ymax <= ymin:
                    print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                    continue
                    
                boxes.append([xmin, ymin, xmax, ymax])
                labels.append(self.class_dict[obj["name"]])
                if "difficult" in obj:
                    iscrowd.append(int(obj["difficult"]))
                else:
                    iscrowd.append(0)
    
            # convert everything into a torch.Tensor
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
            iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
            height_width = torch.as_tensor(height_width, dtype=torch.int64)
            image_id = torch.tensor([idx])
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
    
            target = {}
            target["boxes"] = boxes
            target["labels"] = labels
            target["image_id"] = image_id
            target["area"] = area
            target["iscrowd"] = iscrowd
            target["height_width"] = height_width
    
            if self.transforms is not None:
                image, target = self.transforms(image, target)
    
            return image, target
    
        def get_height_and_width(self, idx):
            # read xml
            xml_path = self.xml_list[idx]
            with open(xml_path) as fid:
                xml_str = fid.read()
            xml = etree.fromstring(xml_str)
            data = self.parse_xml_to_dict(xml)["annotation"]
            data_height = int(data["size"]["height"])
            data_width = int(data["size"]["width"])
            return data_height, data_width
    
        def parse_xml_to_dict(self, xml):
            """
            Parse an XML tree into a dict, following tensorflow's recursive_parse_xml_to_dict.
            Args:
                xml: xml tree obtained by parsing XML file contents using lxml.etree
    
            Returns:
                Python dictionary holding XML contents.
            """
    
            if len(xml) == 0:  # reached a leaf node; return its tag and text directly
                return {xml.tag: xml.text}
    
            result = {}
            for child in xml:
                child_result = self.parse_xml_to_dict(child)  # recurse into child tags
                if child.tag != 'object':
                    result[child.tag] = child_result[child.tag]
                else:
                    if child.tag not in result:  # there may be multiple 'object' tags, so collect them in a list
                        result[child.tag] = []
                    result[child.tag].append(child_result[child.tag])
            return {xml.tag: result}
    
        def coco_index(self, idx):
            """
        This method is prepared specifically for pycocotools to gather label statistics;
        it does not read or process the image at all, which greatly reduces the time needed.

        Args:
            idx: index of the image whose labels are requested
            """
            # read xml
            xml_path = self.xml_list[idx]
            with open(xml_path) as fid:
                xml_str = fid.read()
            xml = etree.fromstring(xml_str)
            data = self.parse_xml_to_dict(xml)["annotation"]
            data_height = int(data["size"]["height"])
            data_width = int(data["size"]["width"])
            height_width = [data_height, data_width]
            # img_path = os.path.join(self.img_root, data["filename"])
            # image = Image.open(img_path)
            # if image.format != "JPEG":
            #     raise ValueError("Image format not JPEG")
            boxes = []
            labels = []
            iscrowd = []
            for obj in data["object"]:
            # convert all gt box coordinates to relative values in [0, 1]
                xmin = float(obj["bndbox"]["xmin"]) / data_width
                xmax = float(obj["bndbox"]["xmax"]) / data_width
                ymin = float(obj["bndbox"]["ymin"]) / data_height
                ymax = float(obj["bndbox"]["ymax"]) / data_height
                boxes.append([xmin, ymin, xmax, ymax])
                labels.append(self.class_dict[obj["name"]])
                iscrowd.append(int(obj["difficult"]))
    
            # convert everything into a torch.Tensor
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
            iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
            height_width = torch.as_tensor(height_width, dtype=torch.int64)
            image_id = torch.tensor([idx])
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
    
            target = {}
            target["boxes"] = boxes
            target["labels"] = labels
            target["image_id"] = image_id
            target["area"] = area
            target["iscrowd"] = iscrowd
            target["height_width"] = height_width
    
            return target
    
        @staticmethod
        def collate_fn(batch):
            images, targets = tuple(zip(*batch))
            # images = torch.stack(images, dim=0)
            #
            # boxes = []
            # labels = []
            # img_id = []
            # for t in targets:
            #     boxes.append(t['boxes'])
            #     labels.append(t['labels'])
            #     img_id.append(t["image_id"])
            # targets = {"boxes": torch.stack(boxes, dim=0),
            #            "labels": torch.stack(labels, dim=0),
            #            "image_id": torch.as_tensor(img_id)}
    
            return images, targets
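
    A sketch of wiring the dataset into a DataLoader, reusing the data_transform dict sketched earlier (the path and batch size are illustrative). collate_fn keeps images and targets as tuples rather than stacking them into one tensor:

    from torch.utils.data import DataLoader

    train_dataset = VOCDataSet("/data", year="2012",
                               transforms=data_transform["train"],
                               train_set="train.txt")
    train_loader = DataLoader(train_dataset,
                              batch_size=8,
                              shuffle=True,
                              num_workers=4,
                              collate_fn=train_dataset.collate_fn)

    # for images, targets in train_loader:
    #     images is a tuple of 8 tensors, targets a tuple of 8 dicts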
    
    

    3. Segmentation tasks

    The .py file defining the VOCSegmentation class:

    import os
    
    import torch.utils.data as data
    from PIL import Image
    
    
    class VOCSegmentation(data.Dataset):
        def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
            super(VOCSegmentation, self).__init__()
            assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
            root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
            assert os.path.exists(root), "path '{}' does not exist.".format(root)
            image_dir = os.path.join(root, 'JPEGImages')
            mask_dir = os.path.join(root, 'SegmentationClass')
    
            txt_path = os.path.join(root, "ImageSets", "Segmentation", txt_name)
            assert os.path.exists(txt_path), "file '{}' does not exist.".format(txt_path)
            with open(os.path.join(txt_path), "r") as f:
                file_names = [x.strip() for x in f.readlines() if len(x.strip()) > 0]
    
            self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
            self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
            assert (len(self.images) == len(self.masks))
            self.transforms = transforms
    
        def __getitem__(self, index):
            """
            Args:
                index (int): Index
    
            Returns:
                tuple: (image, target) where target is the image segmentation.
            """
            img = Image.open(self.images[index]).convert('RGB')
            target = Image.open(self.masks[index])    # the mask file is a single-channel (palette) mask image
    
            if self.transforms is not None:
                img, target = self.transforms(img, target)
    
            return img, target
    
        def __len__(self):
            return len(self.images)
    
        @staticmethod
        def collate_fn(batch):
            images, targets = list(zip(*batch))
            batched_imgs = cat_list(images, fill_value=0)      # pad differently sized images into one same-sized batch tensor, which helps speed up training
            batched_targets = cat_list(targets, fill_value=255)  # same padding for targets; 255 marks padded pixels so the loss can ignore them
            return batched_imgs, batched_targets
    
    
    def cat_list(images, fill_value=0):
        # compute the maximum channel, h, w over this batch
        max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
        batch_shape = (len(images),) + max_size
        batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
        for img, pad_img in zip(images, batched_imgs):
            pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
        return batched_imgs
    
    
    # dataset = VOCSegmentation(voc_root="/data/", transforms=get_transform(train=True))
    # d1 = dataset[0]
    # print(d1)
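
    To make cat_list concrete: each sample is copied into the top-left corner of a batch tensor shaped to the per-dimension maximum, and the remainder is left at fill_value. A small check:

    import torch

    a = torch.ones(3, 4, 5)   # two "images" of different spatial sizes
    b = torch.ones(3, 6, 2)
    batched = cat_list([a, b], fill_value=0)
    print(batched.shape)      # torch.Size([2, 3, 6, 5]): padded to the max h and w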
    

    The rewritten transforms code:

    import numpy as np
    import random
    
    import torch
    from torchvision import transforms as T
    from torchvision.transforms import functional as F
    
    
    def pad_if_smaller(img, size, fill=0):
        # if the image's shorter side is smaller than the given size, pad it with fill
        min_size = min(img.size)
        if min_size < size:
            ow, oh = img.size
            padh = size - oh if oh < size else 0
            padw = size - ow if ow < size else 0
            img = F.pad(img, (0, 0, padw, padh), fill=fill)
        return img
    
    
    class Compose(object):
        def __init__(self, transforms):
            self.transforms = transforms
    
        def __call__(self, image, target):
            for t in self.transforms:
                image, target = t(image, target)
            return image, target
    
    
    class RandomResize(object):
        def __init__(self, min_size, max_size=None):
            self.min_size = min_size
            if max_size is None:
                max_size = min_size
            self.max_size = max_size
    
        def __call__(self, image, target):
            size = random.randint(self.min_size, self.max_size)
            # size is an int here, so the image's shorter side is scaled to size
            image = F.resize(image, size)
            # note the interpolation mode: InterpolationMode.NEAREST only exists in
            # torchvision (0.9.0) and later; earlier versions must use PIL.Image.NEAREST
            target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
            return image, target
    
    
    class RandomHorizontalFlip(object):
        def __init__(self, flip_prob):
            self.flip_prob = flip_prob
    
        def __call__(self, image, target):
            if random.random() < self.flip_prob:
                image = F.hflip(image)
                target = F.hflip(target)
            return image, target
    
    
    class RandomCrop(object):
        def __init__(self, size):
            self.size = size
    
        def __call__(self, image, target):
            image = pad_if_smaller(image, self.size)
            target = pad_if_smaller(target, self.size, fill=255)
            crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
            image = F.crop(image, *crop_params)
            target = F.crop(target, *crop_params)
            return image, target
    
    
    class CenterCrop(object):
        def __init__(self, size):
            self.size = size
    
        def __call__(self, image, target):
            image = F.center_crop(image, self.size)
            target = F.center_crop(target, self.size)
            return image, target
    
    
    class ToTensor(object):
        def __call__(self, image, target):
            image = F.to_tensor(image)
            target = torch.as_tensor(np.array(target), dtype=torch.int64)
            return image, target
    
    
    class Normalize(object):
        def __init__(self, mean, std):
            self.mean = mean
            self.std = std
    
        def __call__(self, image, target):
            image = F.normalize(image, mean=self.mean, std=self.std)
            return image, target
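
    A plausible training composition built from these classes (base_size and crop_size values are illustrative, in the spirit of common FCN training setups):

    base_size, crop_size = 520, 480

    train_transforms = Compose([
        RandomResize(int(0.5 * base_size), int(2.0 * base_size)),
        RandomHorizontalFlip(0.5),
        RandomCrop(crop_size),
        ToTensor(),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    # img_tensor, target_tensor = train_transforms(pil_image, pil_mask)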
    
    
    
    Original article: https://blog.csdn.net/guoqingru0311/article/details/133771598