• Pytorch从零开始实战21


    Pytorch从零开始实战——Pix2Pix理论与实战

    本系列来源于365天深度学习训练营

    原作者K同学

    内容介绍

    Pix2Pix是一种用于用于图像翻译的通用框架,即图像到图像的转换。它在生成对抗网络的框架下进行训练。Pix2Pix的目标是将输入图像转换为输出图像,例如将黑白线稿转换为彩色图像,或者将地图转换为卫星图像等。Pix2Pix模型的训练通常需要大量的配对数据,即包含输入图像与相应输出图像的数据集。

    图像内容:指的是图像的固有内容,它是区分不同图像的依据。

    图像域:指在特定上下文中所涵盖的一组图像的集合,这些图像通常具有某种相似性或共同特征。图像域可以用来表示一类具有共同属性或内容的图像。在图像处理和计算机视觉领域,图像域常常被用于描述参与某项任务或问题的图像集合。

    图像翻译:是将一个物体的图像表征转换为该物体的另一个表征,例如根据皮包的轮廓图得到皮包的彩色图。也就是找到一个函数,能让域A的图像映射到域B,从而实现图像的跨域转换。

    Pix2Pix的三个核心技术:

    基于CGAN的损失函数:CGAN是Conditional Generative Adversarial Network的缩写,它将条件信息(如输入图像)作为生成器和判别器的输入,以帮助生成器生成更加逼真的输出。Pix2Pix使用了基于条件GAN的损失函数来指导生成器生成与目标图像更加接近的输出。

    基于U-Net的生成器:U-Net是一种用于图像分割的卷积神经网络结构,它由编码器和解码器组成,通过跳跃连接将低级特征与高级特征相结合,有助于保留更多的图像细节。Pix2Pix中的生成器采用了U-Net结构,以实现图像到图像的转换。
    下图为Pix2Pix的生成器的网络结构。
    在这里插入图片描述

    基于PatchGAN的判别器:PatchGAN是一种判别器的设计,对图像中的局部区域进行分类。Pix2Pix提出了将输入图像分成个图像块,这有助于提高模型对细节的感知,并使得判别器更能够捕捉到图像的局部结构信息。Pix2Pix使用了基于PatchGAN的判别器来评估生成的图像的真实度。

    数据集加载

    定义了一个自定义的数据集类ImageDataset,用于加载图像数据集。

    import glob
    import random
    import os
    import numpy as np
    from torch.utils.data import Dataset
    from PIL import Image
    import torchvision.transforms as transforms
    
    class ImageDataset(Dataset):
        def __init__(self, root, transforms_=None, mode="train"):
            self.transform = transforms.Compose(transforms_)
    
            self.files = sorted(glob.glob(os.path.join(root, mode) + "/*.*"))
            if mode == "train":
                self.files.extend(sorted(glob.glob(os.path.join(root, "test") + "/*.*")))
    
        def __getitem__(self, index):
    
            img = Image.open(self.files[index % len(self.files)])
            w, h = img.size
            img_A = img.crop((0, 0, w / 2, h))
            img_B = img.crop((w / 2, 0, w, h))
    
            if np.random.random() < 0.5:
                img_A = Image.fromarray(np.array(img_A)[:, ::-1, :], "RGB")
                img_B = Image.fromarray(np.array(img_B)[:, ::-1, :], "RGB")
    
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)
    
            return {"A": img_A, "B": img_B}
    
        def __len__(self):
            return len(self.files)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34

    模型实现

    下面代码定义了一个U-Net生成器和一个PatchGAN鉴别器。

    import torch.nn as nn
    import torch.nn.functional as F
    import torch
    def weights_init_normal(m):
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find("BatchNorm2d") != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    UNetDown 类和 UNetUp 类:这些类定义了 U-Net 结构中的下采样和上采样部分的层次。UNetDown 类用于定义 U-Net 的下采样部分,它包括卷积层、归一化层、激活函数层和丢弃层。UNetUp 类定义了 U-Net 的上采样部分,它包括转置卷积层、归一化层、激活函数层和丢弃层。

    GeneratorUNet 类:这是整个 U-Net 生成器的定义。它利用了之前定义的 UNetDown 和 UNetUp 类来构建一个完整的 U-Net 网络。在 init方法中,它初始化了 U-Net 的各个层次,并在 forward 方法中定义了数据在网络中的传播方式。

    class UNetDown(nn.Module):
        def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
            super(UNetDown, self).__init__()
            layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_size))
            layers.append(nn.LeakyReLU(0.2))
            if dropout:
                layers.append(nn.Dropout(dropout))
            self.model = nn.Sequential(*layers)
    
        def forward(self, x):
            return self.model(x)
    
    
    class UNetUp(nn.Module):
        def __init__(self, in_size, out_size, dropout=0.0):
            super(UNetUp, self).__init__()
            layers = [
                nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
                nn.InstanceNorm2d(out_size),
                nn.ReLU(inplace=True),
            ]
            if dropout:
                layers.append(nn.Dropout(dropout))
    
            self.model = nn.Sequential(*layers)
    
        def forward(self, x, skip_input):
            x = self.model(x)
            x = torch.cat((x, skip_input), 1)
    
            return x
    
    
    class GeneratorUNet(nn.Module):
        def __init__(self, in_channels=3, out_channels=3):
            super(GeneratorUNet, self).__init__()
    
            self.down1 = UNetDown(in_channels, 64, normalize=False)
            self.down2 = UNetDown(64, 128)
            self.down3 = UNetDown(128, 256)
            self.down4 = UNetDown(256, 512, dropout=0.5)
            self.down5 = UNetDown(512, 512, dropout=0.5)
            self.down6 = UNetDown(512, 512, dropout=0.5)
            self.down7 = UNetDown(512, 512, dropout=0.5)
            self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)
    
            self.up1 = UNetUp(512, 512, dropout=0.5)
            self.up2 = UNetUp(1024, 512, dropout=0.5)
            self.up3 = UNetUp(1024, 512, dropout=0.5)
            self.up4 = UNetUp(1024, 512, dropout=0.5)
            self.up5 = UNetUp(1024, 256)
            self.up6 = UNetUp(512, 128)
            self.up7 = UNetUp(256, 64)
    
            self.final = nn.Sequential(
                nn.Upsample(scale_factor=2),
                nn.ZeroPad2d((1, 0, 1, 0)),
                nn.Conv2d(128, out_channels, 4, padding=1),
                nn.Tanh(),
            )
    
        def forward(self, x):
            # U-Net generator with skip connections from encoder to decoder
            d1 = self.down1(x)
            d2 = self.down2(d1)
            d3 = self.down3(d2)
            d4 = self.down4(d3)
            d5 = self.down5(d4)
            d6 = self.down6(d5)
            d7 = self.down7(d6)
            d8 = self.down8(d7)
            u1 = self.up1(d8, d7)
            u2 = self.up2(u1, d6)
            u3 = self.up3(u2, d5)
            u4 = self.up4(u3, d4)
            u5 = self.up5(u4, d3)
            u6 = self.up6(u5, d2)
            u7 = self.up7(u6, d1)
    
            return self.final(u7)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82

    discriminator_block 函数:这个函数定义了 PatchGAN 判别器中的一个“块”,包括一个卷积层、一个归一化层和一个 LeakyReLU 激活函数。

    Discriminator 类:这是整个 PatchGAN 判别器的定义。它由一系列卷积层和 LeakyReLU 激活函数层组成,用于从图像对中提取特征,并输出一个判别值,表示输入图像对是真实对还是生成对。

    class Discriminator(nn.Module):
        def __init__(self, in_channels=3):
            super(Discriminator, self).__init__()
    
            def discriminator_block(in_filters, out_filters, normalization=True):
                """Returns downsampling layers of each discriminator block"""
                layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
                if normalization:
                    layers.append(nn.InstanceNorm2d(out_filters))
                layers.append(nn.LeakyReLU(0.2, inplace=True))
                return layers
    
            self.model = nn.Sequential(
                *discriminator_block(in_channels * 2, 64, normalization=False),
                *discriminator_block(64, 128),
                *discriminator_block(128, 256),
                *discriminator_block(256, 512),
                nn.ZeroPad2d((1, 0, 1, 0)),
                nn.Conv2d(512, 1, 4, padding=1, bias=False)
            )
    
        def forward(self, img_A, img_B):
            # Concatenate image and condition image by channels to produce input
            img_input = torch.cat((img_A, img_B), 1)
            return self.model(img_input)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26

    开始训练

    主要功能包括:

    1.解析命令行参数:使用argparse.ArgumentParser()解析命令行参数,包括训练所需的超参数、数据集名称等。

    2.定义生成器和判别器模型:使用之前定义的GeneratorUNet和Discriminator类创建模型。

    3.配置优化器:使用Adam优化器来优化生成器和判别器的参数。

    4.加载数据集:使用PyTorch的DataLoader加载训练和验证数据集。

    5.定义损失函数:定义了GAN损失和像素级别的L1损失。

    6.训练过程:使用双重循环进行训练,其中外层循环遍历每个epoch,内层循环遍历每个batch。在每个batch内,首先训练生成器,然后训练判别器。训练过程中,会打印损失信息以及训练的进度。

    7.保存模型和生成示例图像:在每个epoch结束时,会保存生成器和判别器的模型参数,并周期性地生成一些示例图像以供可视化和评估模型效果。

    import argparse
    import time
    import datetime
    import sys
    
    import torchvision.transforms as transforms
    from torchvision.utils import save_image
    from torch.utils.data import DataLoader
    from torch.autograd import Variable
    
    from models import *
    from datasets import *
    
    import torch.nn as nn
    import torch.nn.functional as F
    import torch
    
    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
    parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
    parser.add_argument("--dataset_name", type=str, default="data_facades", help="name of the dataset")
    parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_height", type=int, default=256, help="size of image height")
    parser.add_argument("--img_width", type=int, default=256, help="size of image width")
    parser.add_argument("--channels", type=int, default=3, help="number of image channels")
    parser.add_argument(
        "--sample_interval", type=int, default=500, help="interval between sampling of images from generators"
    )
    parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between model checkpoints")
    opt = parser.parse_args()
    print(opt)
    
    os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
    os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)
    
    cuda = True if torch.cuda.is_available() else False
    
    # Loss functions
    criterion_GAN = torch.nn.MSELoss()
    criterion_pixelwise = torch.nn.L1Loss()
    
    # Loss weight of L1 pixel-wise loss between translated image and real image
    lambda_pixel = 100
    
    # Calculate output of image discriminator (PatchGAN)
    patch = (1, opt.img_height // 2 ** 4, opt.img_width // 2 ** 4)
    
    # Initialize generator and discriminator
    generator = GeneratorUNet()
    discriminator = Discriminator()
    
    if cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        criterion_GAN.cuda()
        criterion_pixelwise.cuda()
    
    if opt.epoch != 0:
        # Load pretrained models
        generator.load_state_dict(torch.load("saved_models/%s/generator_%d.pth" % (opt.dataset_name, opt.epoch)))
        discriminator.load_state_dict(torch.load("saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, opt.epoch)))
    else:
        # Initialize weights
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)
    
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    
    # Configure dataloaders
    transforms_ = [
        transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    
    dataloader = DataLoader(
        ImageDataset("./%s" % opt.dataset_name, transforms_=transforms_),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
    )
    
    val_dataloader = DataLoader(
        ImageDataset("./%s" % opt.dataset_name, transforms_=transforms_, mode="val"),
        batch_size=10,
        shuffle=True,
        num_workers=1,
    )
    
    # Tensor type
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    
    
    def sample_images(batches_done):
        """Saves a generated sample from the validation set"""
        imgs = next(iter(val_dataloader))
        real_A = Variable(imgs["B"].type(Tensor))
        real_B = Variable(imgs["A"].type(Tensor))
        fake_B = generator(real_A)
        img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2)
        save_image(img_sample, "images/%s/%s.png" % (opt.dataset_name, batches_done), nrow=5, normalize=True)
    
    
    # ----------
    #  Training
    # ----------
    
    if __name__ == '__main__':
        prev_time = time.time()
    
        for epoch in range(opt.epoch, opt.n_epochs):
            for i, batch in enumerate(dataloader):
    
                # Model inputs
                real_A = Variable(batch["B"].type(Tensor))
                real_B = Variable(batch["A"].type(Tensor))
    
                # Adversarial ground truths
                valid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)
                fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False)
    
                # ------------------
                #  Train Generators
                # ------------------
    
                optimizer_G.zero_grad()
    
                # GAN loss
                fake_B = generator(real_A)
                pred_fake = discriminator(fake_B, real_A)
                loss_GAN = criterion_GAN(pred_fake, valid)
                # Pixel-wise loss
                loss_pixel = criterion_pixelwise(fake_B, real_B)
    
                # Total loss
                loss_G = loss_GAN + lambda_pixel * loss_pixel
    
                loss_G.backward()
    
                optimizer_G.step()
    
                # ---------------------
                #  Train Discriminator
                # ---------------------
    
                optimizer_D.zero_grad()
    
                # Real loss
                pred_real = discriminator(real_B, real_A)
                loss_real = criterion_GAN(pred_real, valid)
    
                # Fake loss
                pred_fake = discriminator(fake_B.detach(), real_A)
                loss_fake = criterion_GAN(pred_fake, fake)
    
                # Total loss
                loss_D = 0.5 * (loss_real + loss_fake)
    
                loss_D.backward()
                optimizer_D.step()
    
                # --------------
                #  Log Progress
                # --------------
    
                # Determine approximate time left
                batches_done = epoch * len(dataloader) + i
                batches_left = opt.n_epochs * len(dataloader) - batches_done
                time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
                prev_time = time.time()
    
                # Print log
                sys.stdout.write(
                    "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s"
                    % (
                        epoch,
                        opt.n_epochs,
                        i,
                        len(dataloader),
                        loss_D.item(),
                        loss_G.item(),
                        loss_pixel.item(),
                        loss_GAN.item(),
                        time_left,
                    )
                )
    
                # If at sample interval save image
                if batches_done % opt.sample_interval == 0:
                    sample_images(batches_done)
    
            if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
                # Save model checkpoints
                torch.save(generator.state_dict(), "saved_models/%s/generator_%d.pth" % (opt.dataset_name, epoch))
                torch.save(discriminator.state_dict(), "saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, epoch))
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203

    训练结果
    在这里插入图片描述

    总结

    Pix2Pix模型在许多领域都有广泛的应用,可以用于许多任务,如将黑白图像转换为彩色图像、将语义标签转换为真实图像、图像超分辨率等。 由于Pix2Pix使用了生成对抗网络,生成的图像可以在一定程度上理解为真实图像与生成图像之间的差异,使得生成结果更具可解释性。

    Pix2Pix的训练通常需要大量的配对数据,即包含输入图像与相应输出图像的数据集,以获得更好的性能和生成效果。

    并且,Pix2Pix模型的训练较为复杂,需要精细调节超参数、选择合适的损失函数以及处理训练不稳定等问题。

  • 相关阅读:
    2022-11-17 更高效的Cascades优化器 - Columbia Query Optimizer
    点成案例丨温度梯度培养箱在探究温度对植物发芽影响中的应用
    字节测试总监8年测试路,如何从功能测试转成自动化测试——愿测试人不再迷茫
    98%的人都不会使用这6种地图可视化方法,学会直接涨薪5K
    业界首创云原生安全检测双模型!安全狗重磅报告亮相数字中国建设峰会
    【C++】引用’&‘的深入解析
    Linux文件系统管理
    基于FPGA的I2C接口控制器(包含单字节和多字节读写)
    深入探讨进程间通信的重要性:理解不同的通信机制(下)
    微信核酸检测预约小程序系统毕业设计毕设(4)开题报告
  • 原文地址:https://blog.csdn.net/weixin_45959662/article/details/136725316