[Generative Networks] Primer (4): CycleGAN Code and Results


    CycleGAN is a milestone work that opened up unpaired style transfer; the zebra-to-horse results are quite striking.
    For the underlying theory, see https://zhuanlan.zhihu.com/p/402819206

    As usual, code first, then some explanation that follows the code.
    The code is adapted from https://github.com/aitorzip/PyTorch-CycleGAN, which is fairly concise; I made a few small modifications.

    import os
    # os.chdir(os.path.dirname(__file__))
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader
    import torch.nn.functional as F
    import torchvision
    from torchvision import transforms
    from torchvision import datasets
    from torchvision import models
    from torch.utils.tensorboard import SummaryWriter
    import numpy as np
    from PIL import Image
    import argparse
    from glob import glob
    import random
    import itertools
    
    ## from https://github.com/aitorzip/PyTorch-CycleGAN
    
    sample_dir = 'samples_cycle_gan'
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir, exist_ok=True)
    
    writer = SummaryWriter(sample_dir)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    np.random.seed(0)
    torch.manual_seed(0)
    
    class ImageDataset(torch.utils.data.Dataset):
        def __init__(self, root, transforms=None, unaligned=False, mode='train'):
            self.transforms = transforms
            self.unaligned = unaligned
            self.files_A = sorted(glob(os.path.join(root, mode, 'A', '*.*')))
            self.files_B = sorted(glob(os.path.join(root, mode, 'B', '*.*')))
    
        def __getitem__(self, idx):
            img = Image.open(self.files_A[idx % len(self.files_A)]).convert('RGB')
            itemA = self.transforms(img)
    
            # unaligned=True pairs each A with a random B, as unpaired training requires
            if self.unaligned:
                rand_idx = random.randint(0, len(self.files_B)-1)
                img = Image.open(self.files_B[rand_idx]).convert('RGB')
                itemB = self.transforms(img)
            else:
                img = Image.open(self.files_B[idx % len(self.files_B)]).convert('RGB')
                itemB = self.transforms(img)
    
            return {
                'A' : itemA,
                'B' : itemB
            }
    
        def __len__(self):
            return max(len(self.files_A), len(self.files_B))
    
    
    class ResidualBlock(nn.Module):
        def __init__(self, in_features):
            super(ResidualBlock, self).__init__()
    
            self.conv_block = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features),
                nn.ReLU(inplace=True),
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features) 
            )
    
        def forward(self, x):
            return x + self.conv_block(x)
    
    class Generator(nn.Module):
        def __init__(self, input_nc, output_nc, n_res_blocks=9):
            super(Generator, self).__init__()
    
            # init basic conv block
            model = [
                nn.ReflectionPad2d(3),
                nn.Conv2d(input_nc, 64, 7),
                nn.InstanceNorm2d(64),
                nn.ReLU(inplace=True)
            ]
    
            # downsampling
            in_features = 64
            out_features = in_features * 2
            for _ in range(2):
                model += [
                    nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),  # kernel 3, stride 2 halves the spatial size
                    nn.InstanceNorm2d(out_features),
                    nn.ReLU(inplace=True)
                ]
                in_features = out_features
                out_features = in_features * 2
            
            # residual blocks
            for _ in range(n_res_blocks):  # 9 blocks by default, as in the paper for 256x256 inputs
                model += [ResidualBlock(in_features)]   
    
            # upsampling
            out_features = in_features //2
            for _ in range(2):
                model += [
                    nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                    nn.InstanceNorm2d(out_features),
                    nn.ReLU(inplace=True)
                ] 
                in_features = out_features
                out_features = in_features //2
            
            # output layer
            model += [
                nn.ReflectionPad2d(3),
                nn.Conv2d(64, output_nc, 7),  # kernel 7 with ReflectionPad2d(3) preserves the spatial size
                nn.Tanh()
            ]
    
            self.model = nn.Sequential(*model)
    
        def forward(self, x):
            return self.model(x)
    
    class Discriminator(nn.Module):
        def __init__(self, input_nc):
            super(Discriminator, self).__init__()
    
            # A bunch of convolutions one after another
            self.model = nn.Sequential(
                nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
    
                nn.Conv2d(64, 128, 4, stride=2, padding=1),
                nn.InstanceNorm2d(128), 
                nn.LeakyReLU(0.2, inplace=True),
    
                nn.Conv2d(128, 256, 4, stride=2, padding=1),
                nn.InstanceNorm2d(256), 
                nn.LeakyReLU(0.2, inplace=True),
    
                nn.Conv2d(256, 512, 4, padding=1),
                nn.InstanceNorm2d(512), 
                nn.LeakyReLU(0.2, inplace=True),
    
                nn.Conv2d(512, 1, 4, padding=1)
            )
    
        def forward(self, x):
            x = self.model(x)
            # average pooling and flatten
            return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
    
    class ReplayBuffer():
        def __init__(self, max_size=50):
            assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
            self.max_size = max_size
            self.data = []
    
        def push_and_pop(self, data):
            # For each new fake: with probability 0.5 return it as-is; otherwise return
            # a fake stored from an earlier batch and keep the new one in its place.
            to_return = []
            for element in data.data:
                element = torch.unsqueeze(element, 0)
                if len(self.data) < self.max_size:
                    self.data.append(element)
                    to_return.append(element)
                else:
                    if random.uniform(0,1) > 0.5:
                        i = random.randint(0, self.max_size-1)
                        to_return.append(self.data[i].clone())
                        self.data[i] = element
                    else:
                        to_return.append(element)
            return torch.cat(to_return)
    
    class LambdaLR():
        def __init__(self, n_epochs, offset, decay_start_epoch):
            assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
            self.n_epochs = n_epochs
            self.offset = offset
            self.decay_start_epoch = decay_start_epoch
    
        def step(self, epoch):
            # LR multiplier: 1.0 until decay_start_epoch, then linear decay to 0 at epoch n_epochs
            return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch)/(self.n_epochs - self.decay_start_epoch)
    
    def weights_init_normal(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm2d') != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)
    
    
    def denorm(x):
        out = (x+1)/2
        return out.clamp(0, 1)
    
    # Networks
    input_nc = 3
    output_nc = 3
    learning_rate = 0.0002
    n_epochs = 200
    decay_epoch = 100
    start_epoch = 0
    batch_size = 16
    input_size = 256
    dataroot = 'data/cycle_gan/datasets/horse2zebra'
    
    netG_A2B = Generator(input_nc, output_nc).to(device)
    netG_B2A = Generator(output_nc, input_nc).to(device)
    netD_A = Discriminator(input_nc).to(device)
    netD_B = Discriminator(output_nc).to(device)
    
    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)
    
    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    
    # optimizer
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), 
                                lr=learning_rate, betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    
    # lr schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(n_epochs, start_epoch, decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(n_epochs, start_epoch, decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(n_epochs, start_epoch, decay_epoch).step)
    
    # Inputs & targets memory allocation
    target_real = torch.ones(batch_size, 1).to(device)   # shape (B, 1) matches the discriminator output
    target_fake = torch.zeros(batch_size, 1).to(device)
    
    
    # Dataset loader
    transforms_data = transforms.Compose([ 
                    transforms.Resize(int(input_size*1.12), interpolation=transforms.InterpolationMode.BICUBIC),
                    transforms.RandomCrop(input_size), 
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) 
                    ])
    
    dataset = ImageDataset(dataroot, transforms=transforms_data, unaligned=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=16, drop_last=True)
    
    
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()
    
    
    ###### Training ######
    cnt = 0
    log_step = 10
    for epoch in range(start_epoch, n_epochs):
        for i, batch in enumerate(dataloader):
            # set model input
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)
    
            ###### Generators ######
            # generators A2B and B2A
            optimizer_G.zero_grad()
    
            ### identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B) * 5.0
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A) * 5.0    
    
            ### GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)
    
            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)
    
            ### Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0
    
            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0
    
            # Total loss
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
            loss_G.backward()
            
            optimizer_G.step()
            ###################################
    
            ###### Discriminator A ######
            optimizer_D_A.zero_grad()
            # real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)
            # fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A)
            loss_D_fake = criterion_GAN(pred_fake, target_fake)
            # total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()
    
            ###### Discriminator B ######
            optimizer_D_B.zero_grad()
            # real loss
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)
            # fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B)
            loss_D_fake = criterion_GAN(pred_fake, target_fake)
            # total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()
            optimizer_D_B.step()
    
            cnt += 1
            if cnt % log_step == 0:
                print('Epoch [{}/{}], Step [{}], LossG: {:.4f}, loss_D_A: {:.4f}, loss_D_B: {:.4f}'.\
                    format(epoch, n_epochs, cnt, loss_G.item(), loss_D_A.item(), loss_D_B.item()))
    
                writer.add_scalar('LossG', loss_G.item(), global_step=cnt)
                writer.add_scalar('loss_D_A', loss_D_A.item(), global_step=cnt)
                writer.add_scalar('loss_D_B', loss_D_B.item(), global_step=cnt)
            if cnt % 100 == 0:
                writer.add_images('real_A', denorm(real_A), global_step=cnt)
                writer.add_images('fake_A', denorm(fake_A), global_step=cnt)
                writer.add_images('recovered_A', denorm(recovered_A), global_step=cnt)
                writer.add_images('real_B', denorm(real_B), global_step=cnt)
                writer.add_images('fake_B', denorm(fake_B), global_step=cnt)
                writer.add_images('recovered_B', denorm(recovered_B), global_step=cnt)
    
        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()
        # Save models checkpoints
        torch.save(netG_A2B.state_dict(), sample_dir + '/netG_A2B.pth')
        torch.save(netG_B2A.state_dict(), sample_dir + '/netG_B2A.pth')
        torch.save(netD_A.state_dict(), sample_dir + '/netD_A.pth')
        torch.save(netD_B.state_dict(), sample_dir + '/netD_B.pth')
    

    Let's walk through the code. First, each sample contains a pair of images, one from domain A and one from domain B, called real_A and real_B.
    Two generators are defined: netG_A2B and netG_B2A.
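
    For reference, here is a quick sanity check of the dataset (a sketch; the directory layout is inferred from the glob patterns in ImageDataset, and the Resize/CenterCrop here is just for the check, not the training transform):

        # expected layout (A = horses, B = zebras for horse2zebra):
        #   data/cycle_gan/datasets/horse2zebra/train/A/*.jpg
        #   data/cycle_gan/datasets/horse2zebra/train/B/*.jpg
        #   data/cycle_gan/datasets/horse2zebra/test/{A,B}/*.jpg
        tfm = transforms.Compose([
            transforms.Resize(286),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
        ])
        ds = ImageDataset(dataroot, transforms=tfm, unaligned=True)
        sample = ds[0]
        print(sample['A'].shape, sample['B'].shape)   # torch.Size([3, 256, 256]) twice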

    Let's look at the Generators part first:

            ###### Generators ######
            # generators A2B and B2A
            optimizer_G.zero_grad()
    
            ### identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B) * 5.0
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A) * 5.0    
    
            ### GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)
    
            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)
    
            ### Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0
    
            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0
    
            # Total loss
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
            loss_G.backward()
            
            optimizer_G.step()
    

    The generator loss has three components (the full weighted objective is summarized after the list):

    • Identity loss. netG_A2B maps A-style images to B-style, so feeding it an image that is already B-style should return essentially the same image; this constraint is the identity loss. The symmetric constraint applies to netG_B2A.
    • GAN loss. The usual generator loss: for each generator, the generated fake images should fool the corresponding discriminator into predicting the "real" label.
    • Cycle loss. Pushing A through netG_A2B and then the result through netG_B2A should recover something close to the original A; this constraint is the cycle-consistency loss. The symmetric B→A→B constraint applies as well.
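
    Putting these together with the weights used in the code (5.0 for identity, 10.0 for cycle, matching the paper's λ = 10 with the identity term at 0.5λ), the total generator objective is just a regrouping of loss_G above:

        loss_G = (loss_GAN_A2B + loss_GAN_B2A) \
               + 10.0 * (loss_cycle_ABA + loss_cycle_BAB) \
               + 5.0 * (loss_identity_A + loss_identity_B)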

    Now for the Discriminator A part (Discriminator B is symmetric).
    This is the standard GAN discriminator loss: real images should be classified as real and generated ones as fake.
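
    One detail: criterion_GAN is MSELoss rather than binary cross-entropy, i.e. the least-squares GAN (LSGAN) objective that the CycleGAN paper adopts for more stable training. Spelled out, the code below is equivalent to:

        # LSGAN discriminator loss; the 0.5 factor matches loss_D_A below
        loss_D_A = 0.5 * ((netD_A(real_A) - 1) ** 2).mean() \
                 + 0.5 * (netD_A(fake_A) ** 2).mean()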

            optimizer_D_A.zero_grad()
            # real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)
            # fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A)
            loss_D_fake = criterion_GAN(pred_fake, target_fake)
            # total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()
    

    It's also worth a look at the Generator architecture.
    The overall structure closely resembles fast style transfer: downsample first, then residual blocks, then upsample, and it likewise uses ReflectionPad2d (see the shape check below).
    Note also that the code uses nn.InstanceNorm2d rather than batch norm.
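
    A quick shape check (a sketch, using the 256×256 input size from training) confirms the hourglass structure:

        # spatial size: 256 -> 128 -> 64 through the two stride-2 convs,
        # unchanged through the residual blocks, then 64 -> 128 -> 256 back up
        g = Generator(3, 3)
        x = torch.randn(1, 3, 256, 256)
        print(g(x).shape)   # torch.Size([1, 3, 256, 256])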

    There isn't much to say about the Discriminator: a few strided convolutions turn the input into a batch_size × 1 × h × w tensor of patch scores, and a final avg_pool2d reduces that to a batch_size × 1 classification output, with no fully connected layer.
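
    Concretely (a sketch):

        d = Discriminator(3)
        x = torch.randn(16, 3, 256, 256)
        print(d.model(x).shape)   # torch.Size([16, 1, 30, 30]): one score per patch
        print(d(x).shape)         # torch.Size([16, 1]) after average pooling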

    One more thing worth mentioning is the ReplayBuffer mechanism. My understanding is that during the discriminator step, fake_A and fake_B are pushed into a buffer, and with some probability a previously stored fake is returned in their place, so the discriminator also sees generated samples from earlier batches. My guess is that this keeps the discriminator from over-specializing on the current real/fake pairs and exposes it to a broader pool of positive and negative samples.
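
    A tiny demonstration of the buffer behavior (a sketch):

        buf = ReplayBuffer(max_size=2)
        out1 = buf.push_and_pop(torch.zeros(2, 3, 4, 4))   # fills the buffer, passes through
        out2 = buf.push_and_pop(torch.ones(2, 3, 4, 4))    # each sample: 50% pass-through,
        print(out2.unique())                               # 50% swapped with a stored zero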

    Because of this mechanism, though, the real and fake images I log during training are not one-to-one aligned, which makes it awkward to eyeball the results. That is easy to fix, but I was lazy.
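
    The fix would be to keep an aligned copy for logging and feed only the buffered batch to the discriminator, roughly (a sketch; fake_A_for_log is a name I made up):

        fake_A_for_log = fake_A.detach()               # still aligned with real_B
        fake_A = fake_A_buffer.push_and_pop(fake_A)    # possibly-shuffled batch for netD_A
        # ...and in the logging branch:
        # writer.add_images('fake_A', denorm(fake_A_for_log), global_step=cnt)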

    Now for the results. A is regular horses and B is zebras.
    Here is the zebra → horse direction after conversion:
    (figure: zebra → horse results)

    And here is the horse → zebra direction:

    (figure: horse → zebra results)
    Not great; a long way from the paper's results, so there is presumably still plenty to tune. If you want paper-quality output, try the official code: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
