Python & PyTorch Image Harmonization Model: BargainNet



    Preface

    BargainNet is a project from bcmi; see the GitHub link for the full project description. I needed BargainNet for various reasons, but I am not used to launching training from the command line, so I pulled out the default model and parameters it uses and condensed everything into two simple files: "reading the data" and "training the model".


    I. File Structure

    The training data is organized as follows (a sketch of the expected layout is given at the end of this section):

    IHD_train.txt is very simple: it is just a list of file names, one per line.

    Beyond that, the data-loading code and the model code just sit in the same folder; change the dataset path in the data-loading code and it will run.
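
    From the way HarmonyDataset.py builds paths (replace "com" with "mask" or "gt" in the composite path, then strip the trailing "_<id>" and append ".png"), the layout it expects looks roughly like the sketch below. The file names are made up for illustration; only the folder names and the naming rule matter:

    /app/data/
    ├── IHD_train.txt          # one composite file name per line, e.g. c_001_1.jpg
    ├── com/                   # composite images (this folder is the dataset_root)
    │   └── c_001_1.jpg
    ├── mask/                  # foreground masks
    │   └── c_001.png
    └── gt/                    # ground-truth (real) images
        └── c_001.png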

    II. Reading the Data

    The file is named HarmonyDataset.py so that the model script can import it easily.

    1. Import libraries

    import os.path
    import random
    from abc import ABC
    
    import cv2
    import numpy as np
    import torch
    import torch.utils.data as data
    import torchvision.transforms as transforms
    from albumentations import HorizontalFlip, RandomResizedCrop, Compose, DualTransform, ToGray
    

    2. Loading the data

    class HCompose(Compose):
        def __init__(self, transforms, *args, additional_targets=None, no_nearest_for_masks=True, **kwargs):
            if additional_targets is None:
                additional_targets = {
                    'real': 'image',
                    'mask': 'mask'
                }
            self.additional_targets = additional_targets
            super().__init__(transforms, *args, additional_targets=additional_targets, **kwargs)
            if no_nearest_for_masks:
                for t in transforms:
                    if isinstance(t, DualTransform):
                        t._additional_targets['mask'] = 'image'
    
    
    def get_transform(params=None, no_flip=True, grayscale=False):
        transform_list = []
        if grayscale:
            transform_list.append(ToGray())
        if params is None:
            transform_list.append(RandomResizedCrop(512, 512, scale=(0.5, 1.0)))
    
        if not no_flip:
            if params is None:
                transform_list.append(HorizontalFlip())
    
        return HCompose(transform_list)
    
    
    class Iharmony4Dataset(data.Dataset, ABC):
        def __init__(self, dataset_root,):
            self.image_paths = []
            print('loading training file: ')
            self.keep_background_prob = 0.05
            self.file = dataset_root.replace("com", "") + 'IHD_train.txt'
            with open(self.file, 'r') as f:
                for line in f.readlines():
                    self.image_paths.append(os.path.join(dataset_root, line.rstrip()))
    
            self.transform = get_transform()
            self.input_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
    
        def __getitem__(self, index):
            sample = self.get_sample(index)
            self.check_sample_types(sample)
            sample = self.augment_sample(sample)
            comp = self.input_transform(sample['image'])
            real = self.input_transform(sample['real'])
            mask = sample['mask'].astype(np.float32)
            mask = mask[np.newaxis, ...].astype(np.float32)
            output = {
                'comp': comp.unsqueeze(0),
                'mask': torch.from_numpy(mask).unsqueeze(0),
                'real': real.unsqueeze(0),
                'img_path': sample['img_path']
            }
            return output
    
        def check_sample_types(self, sample):
            assert sample['comp'].dtype == 'uint8'
            if 'real' in sample:
                assert sample['real'].dtype == 'uint8'
    
        def augment_sample(self, sample):
            if self.transform is None:
                return sample
            additional_targets = {target_name: sample[target_name]
                                  for target_name in self.transform.additional_targets.keys()}
    
            valid_augmentation = False
            while not valid_augmentation:
                aug_output = self.transform(image=sample['comp'], **additional_targets)
                valid_augmentation = self.check_augmented_sample(aug_output)
    
            for target_name, transformed_target in aug_output.items():
                sample[target_name] = transformed_target
    
            return sample
    
        def check_augmented_sample(self, aug_output):
            if self.keep_background_prob < 0.0 or random.random() < self.keep_background_prob:
                return True
    
            return aug_output['mask'].sum() > 1.0
    
        def get_sample(self, index):
            path = self.image_paths[index]
            name_parts = path.split('_')
            mask_path = self.image_paths[index].replace('com', 'mask')
            mask_path = mask_path.replace(('_' + name_parts[-1]), '.png')
            target_path = self.image_paths[index].replace('com', 'gt')
            target_path = target_path.replace(('_' + name_parts[-1]), '.png')
    
            comp = cv2.imread(path)
            comp = cv2.cvtColor(comp, cv2.COLOR_BGR2RGB)
            real = cv2.imread(target_path)
            real = cv2.cvtColor(real, cv2.COLOR_BGR2RGB)
            mask = cv2.imread(mask_path)
            mask = mask[:, :, 0].astype(np.float32) / 255.
            mask = mask.astype(np.uint8)
    
            return {'comp': comp, 'mask': mask, 'real': real, 'img_path': path}
    
        def __len__(self):
            return len(self.image_paths)
    
    

    comp is the composite image, mask marks the composited (foreground) region, and real is the ground truth.
    The goal, naturally, is to map comp -> real.
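
    A minimal usage sketch of my own (not from the original post) showing what indexing the dataset returns; the dataset root matches the placeholder path used in the training script later:

    harmony_dataset = Iharmony4Dataset(dataset_root='/app/data/com/')
    sample = harmony_dataset[0]
    # Every field already carries a batch dimension because __getitem__ calls unsqueeze(0),
    # so samples can be fed to the model directly, without a DataLoader.
    print(sample['comp'].shape)   # torch.Size([1, 3, 512, 512]), composite, normalized to [-1, 1]
    print(sample['real'].shape)   # torch.Size([1, 3, 512, 512]), ground truth
    print(sample['mask'].shape)   # torch.Size([1, 1, 512, 512]), foreground region mask
    print(sample['img_path'])     # path of the composite image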


    III. The Model

    You can name this file anything you like.

    1. Import libraries

    import functools
    
    import torch
    import torch.nn.functional as F
    import tqdm
    from torch import nn
    from torch.nn import init
    from torch.optim import lr_scheduler
    

    2. Model structure: the generator G

    class UnetGenerator(nn.Module):
        def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
                     use_attention=False):
            super(UnetGenerator, self).__init__()
            # construct unet structure
            weight = torch.FloatTensor([0.1])
            self.weight = torch.nn.Parameter(weight, requires_grad=True)
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer,
                                                 innermost=True)  # add the innermost layer
            for i in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
                unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                     norm_layer=norm_layer, use_dropout=use_dropout)
            # gradually reduce the number of filters from ngf * 8 to ngf
            unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, use_attention=use_attention)
            unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, use_attention=use_attention)
            unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer,
                                                 use_attention=use_attention)
            self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True,
                                                 norm_layer=norm_layer)  # add the outermost layer
    
        def forward(self, inputs):
            ori_code_map = inputs[:, 4:, :, :]
            code_map_input = ori_code_map * torch.clamp(self.weight, min=0.001)
            new_inputs = torch.cat([inputs[:, :4, :, :], code_map_input], 1)
            return self.model(new_inputs)
    
    
    class UnetSkipConnectionBlock(nn.Module):
        def __init__(self, outer_nc, inner_nc, input_nc=None,
                     submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False,
                     use_attention=False):
            super(UnetSkipConnectionBlock, self).__init__()
            self.outermost = outermost
            if type(norm_layer) == functools.partial:
                use_bias = norm_layer.func == nn.InstanceNorm2d
            else:
                use_bias = norm_layer == nn.InstanceNorm2d
            if input_nc is None:
                input_nc = outer_nc
            downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                                 stride=2, padding=1, bias=use_bias)
            downrelu = nn.LeakyReLU(0.2, True)
            downnorm = norm_layer(inner_nc)
            uprelu = nn.ReLU(True)
            upnorm = norm_layer(outer_nc)
    
            if outermost:
                upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
                down = [downconv]
                up = [uprelu, upconv, nn.Tanh()]
                model = down + [submodule] + up
            elif innermost:
                upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
                down = [downrelu, downconv]
                up = [uprelu, upconv, upnorm]
                model = down + up
            else:
                upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
                down = [downrelu, downconv, downnorm]
                up = [uprelu, upconv, upnorm]
    
                if use_dropout:
                    model = down + [submodule] + up + [nn.Dropout(0.5)]
                else:
                    model = down + [submodule] + up
    
            self.use_attention = use_attention
            if use_attention:
                attention_conv = nn.Conv2d(outer_nc + input_nc, outer_nc + input_nc, kernel_size=1)
                attention_sigmoid = nn.Sigmoid()
                self.attention = nn.Sequential(*[attention_conv, attention_sigmoid])
    
            self.model = nn.Sequential(*model)
    
        def forward(self, x):
            if self.outermost:
                return self.model(x)
            else:
                ret = torch.cat([x, self.model(x)], 1)
                return self.attention(ret) * ret if self.use_attention else ret
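
    As used later in this post, the generator is built with input_nc = 20: channels 0-2 are the RGB composite, channel 3 is the mask, and channels 4-19 are the 16-dimensional background style code broadcast to a spatial map; forward() rescales only those last 16 channels by the learnable weight. A minimal sketch of my own (not from the original post) makes the input layout concrete:

    G = UnetGenerator(20, 3, 8, 64, nn.BatchNorm2d, False, use_attention=True)

    comp = torch.randn(1, 3, 512, 512)       # channels 0-2: composite image
    mask = torch.ones(1, 1, 512, 512)        # channel 3: foreground mask
    code_map = torch.randn(1, 16, 512, 512)  # channels 4-19: background style code broadcast spatially
    inputs = torch.cat([comp, mask, code_map], dim=1)

    out = G(inputs)
    print(out.shape)                         # torch.Size([1, 3, 512, 512]), Tanh output in [-1, 1]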
    

    3. Model structure: the style encoder E

    class PartialConv2d(nn.Conv2d):
        def __init__(self, *args, **kwargs):
            # whether the mask is multi-channel or not
            if 'multi_channel' in kwargs:
                self.multi_channel = kwargs['multi_channel']
                kwargs.pop('multi_channel')
            else:
                self.multi_channel = False
    
            self.return_mask = True
    
            super(PartialConv2d, self).__init__(*args, **kwargs)
    
            if self.multi_channel:
                self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0],
                                                     self.kernel_size[1])
            else:
                self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])
    
            self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * \
                                 self.weight_maskUpdater.shape[3]
    
            self.last_size = (None, None, None, None)
            self.update_mask, self.mask_ratio = None, None
    
        def forward(self, input, mask_in=None):
            assert len(input.shape) == 4
            if mask_in is not None or self.last_size != tuple(input.shape):
                self.last_size = tuple(input.shape)
    
                with torch.no_grad():
                    if self.weight_maskUpdater.type() != input.type():
                        self.weight_maskUpdater = self.weight_maskUpdater.to(input)
    
                    if mask_in is None:
                        # if mask is not provided, create a mask
                        if self.multi_channel:
                            mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2],
                                              input.data.shape[3]).to(input)
                        else:
                            mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)
                    else:
                        mask = mask_in
    
                    self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride,
                                                padding=self.padding, dilation=self.dilation, groups=1)
    
                    self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
                    self.update_mask = torch.clamp(self.update_mask, 0, 1)
                    self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
    
            raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)
    
            if self.bias is not None:
                bias_view = self.bias.view(1, self.out_channels, 1, 1)
                output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
                output = torch.mul(output, self.update_mask)
            else:
                output = torch.mul(raw_out, self.mask_ratio)
    
            if self.return_mask:
                return output, self.update_mask
            else:
                return output
    
    
    class StyleEncoder(nn.Module):
        def __init__(self, style_dim, norm_layer=nn.BatchNorm2d):
            super(StyleEncoder, self).__init__()
            if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
                use_bias = norm_layer.func == nn.InstanceNorm2d
            else:
                use_bias = norm_layer == nn.InstanceNorm2d
            ndf = 64
            kw = 3
            padw = 0
            self.conv1f = PartialConv2d(3, ndf, kernel_size=kw, stride=2, padding=padw)
            self.relu1 = nn.ReLU(True)
            nf_mult = 1
    
            n = 1
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            self.conv2f = PartialConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw,
                                        bias=use_bias)
            self.norm2f = norm_layer(ndf * nf_mult)
            self.relu2 = nn.ReLU(True)
    
            n = 2
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            self.conv3f = PartialConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw,
                                        bias=use_bias)
            self.norm3f = norm_layer(ndf * nf_mult)
            self.relu3 = nn.ReLU(True)
    
            n = 3
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            self.conv4f = PartialConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw,
                                        bias=use_bias)
            self.norm4f = norm_layer(ndf * nf_mult)
            self.relu4 = nn.ReLU(True)
    
            n = 4
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            self.conv5f = PartialConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw,
                                        bias=use_bias)
            self.avg_pooling = nn.AdaptiveAvgPool2d(1)
            self.convs = nn.Conv2d(ndf * nf_mult, style_dim, kernel_size=1, stride=1)
    
        def forward(self, input, mask):
            """Standard forward."""
            xb = input
            mb = mask
    
            xb, mb = self.conv1f(xb, mb)
            xb = self.relu1(xb)
            xb, mb = self.conv2f(xb, mb)
            xb = self.norm2f(xb)
            xb = self.relu2(xb)
            xb, mb = self.conv3f(xb, mb)
            xb = self.norm3f(xb)
            xb = self.relu3(xb)
            xb, mb = self.conv4f(xb, mb)
            xb = self.norm4f(xb)
            xb = self.relu4(xb)
            xb, mb = self.conv5f(xb, mb)
            xb = self.avg_pooling(xb)
            s = self.convs(xb)
            return s
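
    A small usage sketch of my own: the encoder takes an image and a single-channel region mask (the foreground mask, or 1 - mask for the background) and returns one style code per image:

    E = StyleEncoder(16, norm_layer=nn.BatchNorm2d)

    img = torch.randn(1, 3, 512, 512)
    region = torch.ones(1, 1, 512, 512)   # the region whose style we want to encode

    style = E(img, region)
    print(style.shape)                    # torch.Size([1, 16, 1, 1])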
    

    4. Initializing the model and weights

    def init_weights(net, init_type='normal', init_gain=0.02):
        """Initialize network weights.
    
        Parameters:
            net (network)   -- network to be initialized
            init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
            init_gain (float)    -- scaling factor for normal, xavier and orthogonal.
    
        We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
        work better for some applications. Feel free to try yourself.
        """
    
        def init_func(m):  # define the initialization function
            classname = m.__class__.__name__
            if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    init.normal_(m.weight.data, 0.0, init_gain)
                elif init_type == 'xavier':
                    init.xavier_normal_(m.weight.data, gain=init_gain)
                elif init_type == 'kaiming':
                    init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    init.orthogonal_(m.weight.data, gain=init_gain)
                else:
                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif classname.find(
                    'BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
                init.normal_(m.weight.data, 1.0, init_gain)
                init.constant_(m.bias.data, 0.0)
    
        print('initialize network with %s' % init_type)
        net.apply(init_func)  # apply the initialization function 
    
    
    def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
        """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
        Parameters:
            net (network)      -- the network to be initialized
            init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
            gain (float)       -- scaling factor for normal, xavier and orthogonal.
            gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
    
        Return an initialized network.
        """
    
        if len(gpu_ids) > 0:
            assert (torch.cuda.is_available())
            net.to(gpu_ids[0])
            net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
        init_weights(net, init_type, init_gain=init_gain)
        return net
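
    For illustration only (the original post never calls these helpers directly; BargainNetModel below invokes init_net internally), initializing the two networks on CPU could look like this:

    netG = init_net(UnetGenerator(20, 3, 8, use_attention=True), init_type='normal', init_gain=0.02, gpu_ids=[])
    netE = init_net(StyleEncoder(16), init_type='normal', init_gain=0.02, gpu_ids=[])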
    

    5. Building BargainNet

    class BargainNetModel:
        def __init__(self, netE, netG, style_dim=16, img_size=512, init_type='normal', init_gain=0.02, gpu_ids=[]):
            self.gpu_ids = gpu_ids
            self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
            self.lambda_tri = 0.01
            self.lambda_f2b = 1.0
            self.lambda_ff2 = 1.0
            self.loss_names = ['L1', 'tri']
            self.optimizers = []
            self.lr = 0.0002
            self.e_lr_ratio = 1.0
            self.g_lr_ratio = 1.0
            self.beta1 = 0.5
    
            self.style_dim = style_dim
            self.image_size = img_size
            self.netE = init_net(netE, init_type, init_gain, self.gpu_ids)
            self.netG = init_net(netG, init_type, init_gain, self.gpu_ids)
            self.relu = nn.ReLU()
            self.margin = 0.1
            self.tripletLoss = nn.TripletMarginLoss(margin=self.margin, p=2)
            self.criterionL1 = torch.nn.L1Loss()
            self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=self.lr * self.e_lr_ratio,
                                                betas=(self.beta1, 0.999))
            self.optimizers.append(self.optimizer_E)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=self.lr * self.g_lr_ratio,
                                                betas=(self.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.schedulers = [
                lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=0) for optimizer in self.optimizers
            ]
    
        def set_input(self, input):
            self.comp = input['comp'].to(self.device)
            self.real = input['real'].to(self.device)
            self.mask = input['mask'].to(self.device)
            self.inputs = torch.cat([self.comp, self.mask], 1).to(self.device)
            self.bg = 1.0 - self.mask
            self.real_f = self.real * self.mask
    
        def forward(self):
            self.bg_sty_vector = self.netE(self.real, self.bg)
            self.real_fg_sty_vector = self.netE(self.real, self.mask)
            self.bg_sty_map = self.bg_sty_vector.expand([1, self.style_dim, self.image_size, self.image_size])
            self.inputs_c2r = torch.cat([self.inputs, self.bg_sty_map], 1)
            self.harm = self.netG(self.inputs_c2r)
    
            self.harm_fg_sty_vector = self.netE(self.harm, self.mask)
            self.comp_fg_sty_vector = self.netE(self.comp, self.mask)
            self.fake_f = self.harm * self.mask
    
        def backward(self):
            self.loss_L1 = self.criterionL1(self.harm, self.real)
            self.loss_tri = (self.tripletLoss(self.real_fg_sty_vector, self.harm_fg_sty_vector,
                                              self.comp_fg_sty_vector) * self.lambda_ff2
                             + self.tripletLoss(self.harm_fg_sty_vector, self.bg_sty_vector,
                                                self.comp_fg_sty_vector) * self.lambda_f2b) * self.lambda_tri
            self.loss = self.loss_L1 + self.loss_tri
            self.loss.backward(retain_graph=True)
    
        def optimize_parameters(self):
            self.forward()
            self.optimizer_E.zero_grad()
            self.optimizer_G.zero_grad()
            self.backward()
            self.optimizer_E.step()
            self.optimizer_G.step()
    

    6. Training the model

    The main change is here: the original training loop now gets a tqdm progress bar, so you can watch ["l1_loss", "tri_loss", "l1_loss + tri_loss"] evolve during training, which is more intuitive.

    from HarmonyDataset import Iharmony4Dataset simply imports the data-loading code above; as long as your file name matches, it works.

    # Parameters follow the repository's default invocation; the official training command is:
    """
    python train.py --name  --model bargainnet --dataset_mode iharmony4 --is_train 1 --norm batch --preprocess resize_and_crop --gpu_ids 0 --save_epoch_freq 1 --input_nc 20 --lr 1e-4 --beta1 0.9 --lr_policy step --lr_decay_iters 6574200 --netG s2ad
    """
    G_net = UnetGenerator(20, 3, 8, 64, nn.BatchNorm2d, False, use_attention=True)
    E_net = StyleEncoder(16, norm_layer=nn.BatchNorm2d)
    
    if __name__ == "__main__":
        from HarmonyDataset import Iharmony4Dataset
    
        harmony_dataset = Iharmony4Dataset(dataset_root='/app/data/com/')
        datalen = len(harmony_dataset)
        model = BargainNetModel(E_net, G_net, gpu_ids=[])
        EPOCH = 20
        best_loss = 0.3  # best loss, default as 0.3
        for epoch in range(EPOCH):
            tqdm_bar = tqdm.tqdm(enumerate(harmony_dataset), total=datalen, desc='Epoch {}/{}'.format(epoch + 1, EPOCH))
            epoch_l1, epoch_tri = 0, 0
            for i, data in tqdm_bar:
                model.set_input(data)  # unpack data from a dataset and apply preprocessing
                model.optimize_parameters()  # calculate loss functions, get gradients, update network weights
                epoch_l1 += model.loss_L1.item()
                epoch_tri += model.loss_tri.item()
                tqdm_bar.set_postfix(L1=epoch_l1 / (i + 1), tri=epoch_tri / (i + 1),
                                     total=(epoch_l1 + epoch_tri) / (i + 1), best_loss=best_loss)
    
            if best_loss > (epoch_l1 + epoch_tri) / datalen:  # cache our latest model every  iterations
                print('the best model improve loss from {0} to {1}'.format(best_loss, (epoch_l1 + epoch_tri) / datalen))
                best_loss = (epoch_l1 + epoch_tri) / datalen
                # model save weights
                torch.save(model.netG.state_dict(), 'best_netG.pth')
                torch.save(model.netE.state_dict(), 'best_netE.pth')
    
            # update learning rates at the end of every epoch.
            for scheduler in model.schedulers:
                scheduler.step()
    
    # Optionally export a visualization of the generator:
    # x = torch.zeros(1, 20, 512, 512, dtype=torch.float, requires_grad=False)
    # import hiddenlayer as h
    # myNetGraph = h.build_graph(G_net, x)  # build the network graph
    # myNetGraph.save(path='./demoModel-G', format='pdf')  # save the graph; png, pdf, etc. are supported
    
    else:
        G_net.load_state_dict(torch.load('/app/checkpoints/best_net_G.pth'))
        E_net.load_state_dict(torch.load('/app/checkpoints/best_net_E.pth'))
        print('model load weights success')
    
    

    Summary

    For a quick prediction test, I grabbed a random medicine image from Baidu, cut out its background, and pasted it onto another photo. At inference time, simply replace real with comp when extracting the background style; a minimal sketch follows.
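
    Below is a minimal inference sketch of my own (not from the original post), assuming the checkpoints saved by the training loop above ('best_netG.pth' / 'best_netE.pth') exist; G_net, E_net and Iharmony4Dataset are the objects defined earlier. It follows the same steps as BargainNetModel.forward(), except that the background style is extracted from comp instead of real:

    G_net.load_state_dict(torch.load('best_netG.pth', map_location='cpu'))
    E_net.load_state_dict(torch.load('best_netE.pth', map_location='cpu'))
    G_net.eval()
    E_net.eval()

    with torch.no_grad():
        sample = Iharmony4Dataset(dataset_root='/app/data/com/')[0]
        comp, mask = sample['comp'], sample['mask']

        bg_sty = E_net(comp, 1.0 - mask)                           # background style code, taken from comp instead of real
        bg_sty_map = bg_sty.expand([1, 16, 512, 512])              # broadcast the 16-d code to a spatial map
        harm = G_net(torch.cat([comp, mask, bg_sty_map], dim=1))   # harmonized output in [-1, 1], shape (1, 3, 512, 512)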

    For the full project, please refer to the paper's GitHub implementation: https://github.com/bcmi/BargainNet-Image-Harmonization

    That should be everything.

    If you have questions about the model itself, please read the paper and ask the authors; I am just the code porter.

    finish

  • Original post: https://blog.csdn.net/qq_43190189/article/details/126029901