• GAN-生成对抗网络(Pytorch)合集(2)--pixtopix-CycleGAN


    pix2pix(像素到像素)

    原文链接:https://arxiv.org/pdf/1611.07004.pdf
    输入一个域的图片转换为另一个域的图片(白天照片转成黑夜)
    如下图,输入标记图片,输出真实图片。缺点是训练集中两个域的图片必须一一对应,所以叫pix2pix。
    在这里插入图片描述

    网络结构有点复杂,用到了语义分割的UNET网络结构
    在这里插入图片描述
    数据集
    地址忘了,也是官方的,想起来补
    代码:这里是建筑物labels to facade的例子

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils import data
    import torchvision
    from torchvision import transforms
    import numpy as np
    import matplotlib.pyplot as plt
    import os
    import glob
    from PIL import Image
    
    # Training data: the *.jpg files are the original photos and the *.png
    # files are the segmentation label maps.
    # CMP_dataset pairs the two lists purely by index, and glob makes no
    # ordering guarantee, so both lists are sorted to keep photo i aligned with
    # label map i (assumes a photo and its label share a file stem - confirm
    # against the dataset).
    # os.path.join replaces the hard-coded backslash so the pattern also
    # matches on non-Windows systems; on Windows the pattern is unchanged.
    images_path = sorted(glob.glob(os.path.join('base', '*.jpg')))
    annos_path = sorted(glob.glob(os.path.join('base', '*.png')))

    # Shared preprocessing: PIL image -> tensor -> 256x256 -> Normalize(0.5, 0.5),
    # i.e. (x - 0.5) / 0.5, which matches the tanh output range of the generator.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((256, 256)),
        transforms.Normalize(0.5, 0.5)
    ])
    
    
    class CMP_dataset(data.Dataset):
        """Paired dataset that yields ``(label_map, photo)`` tensors.

        The two path lists are indexed in lockstep: item ``i`` combines
        ``annos_path[i]`` with ``imgs_path[i]``. Both images are passed through
        the module-level ``transform``.
        """

        def __init__(self, imgs_path, annos_path):
            self.imgs_path, self.annos_path = imgs_path, annos_path

        def __len__(self):
            # The photo list defines the dataset size.
            return len(self.imgs_path)

        def __getitem__(self, item):
            photo_file, label_file = self.imgs_path[item], self.annos_path[item]

            photo = transform(Image.open(photo_file))

            # Label maps are converted to 3-channel RGB before the transform.
            label = transform(Image.open(label_file).convert('RGB'))

            # Order matters to callers: network input first, target second.
            return label, photo
    
    
    dataset = CMP_dataset(images_path, annos_path)
    batchsize = 32
    dataloader = data.DataLoader(dataset, batch_size=batchsize, shuffle=True)

    # Sanity check: draw one shuffled batch and show up to three
    # (label map, photo) pairs, one pair per row.
    annos_batch, images_batch = next(iter(dataloader))

    for i, (anno, img) in enumerate(zip(annos_batch[:3], images_batch[:3])):
        # CHW -> HWC for matplotlib, then undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1].
        anno = (anno.permute(1, 2, 0).numpy()+1)/2
        img = (img.permute(1, 2, 0).numpy()+1)/2
        row = (('input_img', anno), ('output_img', img))
        for offset, (caption, picture) in enumerate(row, start=1):
            plt.subplot(3, 2, i*2+offset)
            plt.title(caption)
            plt.imshow(picture)
    plt.show()
    
    # Downsampling block
    class Downsample(nn.Module):
        """3x3 stride-2 convolution + LeakyReLU, optionally followed by BatchNorm.

        With kernel 3, stride 2 and padding 1, an even spatial size is halved.
        """

        def __init__(self, in_channels, out_channels):
            super().__init__()
            conv = nn.Conv2d(in_channels, out_channels,
                             kernel_size=3, stride=2, padding=1)
            self.conv_relu = nn.Sequential(conv, nn.LeakyReLU(inplace=True))
            self.bn = nn.BatchNorm2d(out_channels)

        def forward(self, x, is_bn=True):
            """Apply conv + activation; skip the normalization when ``is_bn`` is False."""
            out = self.conv_relu(x)
            return self.bn(out) if is_bn else out
    
    
    # Upsampling block
    class Upsample(nn.Module):
        """3x3 stride-2 transposed convolution + LeakyReLU + BatchNorm.

        With kernel 3, stride 2, padding 1 and output_padding 1 the spatial
        size is doubled. Dropout is applied on request.
        """

        def __init__(self, in_channels, out_channels):
            super().__init__()
            deconv = nn.ConvTranspose2d(in_channels, out_channels,
                                        kernel_size=3, stride=2, padding=1,
                                        output_padding=1)
            self.upconv_relu = nn.Sequential(deconv, nn.LeakyReLU(inplace=True))
            self.bn = nn.BatchNorm2d(out_channels)

        def forward(self, x, is_drop=False):
            """Upsample ``x``; apply channel dropout afterwards when ``is_drop`` is True."""
            out = self.bn(self.upconv_relu(x))
            if not is_drop:
                return out
            # F.dropout2d is called with its default training=True, so the
            # dropout is active regardless of self.training.
            return F.dropout2d(out)
    
    
    # Generator: 6 downsampling blocks, 5 upsampling blocks, 1 output layer
    class Generator(nn.Module):
        """U-Net style encoder/decoder with skip connections.

        Each decoder stage output is concatenated (channel-wise) with the
        encoder activation of the same resolution, which is why the decoder
        input widths are doubled. The shape comments assume a 3x256x256 input.
        """

        def __init__(self):
            super().__init__()
            # Encoder (C, H, W after each stage)
            self.down1 = Downsample(3, 64)      # 64, 128, 128
            self.down2 = Downsample(64, 128)    # 128, 64, 64
            self.down3 = Downsample(128, 256)   # 256, 32, 32
            self.down4 = Downsample(256, 512)   # 512, 16, 16
            self.down5 = Downsample(512, 512)   # 512, 8, 8
            self.down6 = Downsample(512, 512)   # 512, 4, 4

            # Decoder; inputs from up2 onwards include the concatenated skip.
            self.up1 = Upsample(512, 512)       # 512, 8, 8
            self.up2 = Upsample(1024, 512)      # 512, 16, 16
            self.up3 = Upsample(1024, 256)      # 256, 32, 32
            self.up4 = Upsample(512, 128)       # 128, 64, 64
            self.up5 = Upsample(256, 64)        # 64, 128, 128

            # Final upsampling back to 3 channels at the input resolution.
            self.last = nn.ConvTranspose2d(128, 3,
                                           kernel_size=3,
                                           stride=2,
                                           padding=1,
                                           output_padding=1)

        def forward(self, x):
            # Encoder pass, keeping every activation for the skip connections.
            skips = []
            feat = x
            encoders = (self.down1, self.down2, self.down3,
                        self.down4, self.down5, self.down6)
            for encoder in encoders:
                feat = encoder(feat)
                skips.append(feat)
            # The bottleneck output feeds the decoder; it is not a skip itself.
            skips.pop()

            # Decoder pass: the first four stages use dropout, the last does not.
            decoders = ((self.up1, True), (self.up2, True), (self.up3, True),
                        (self.up4, True), (self.up5, False))
            for decoder, use_dropout in decoders:
                feat = decoder(feat, is_drop=use_dropout)
                feat = torch.cat([feat, skips.pop()], dim=1)

            # tanh keeps the output in [-1, 1].
            return torch.tanh(self.last(feat))
    
    
    # Discriminator: scores a (label map, image) pair
    class Discriminator(nn.Module):
        """Conditional discriminator producing a map of per-location scores.

        The label map and the image are concatenated on the channel axis, so
        the first block takes 6 channels. The shape comments assume 256x256
        inputs.
        """

        def __init__(self):
            super().__init__()
            self.down1 = Downsample(6, 64)     # 64 x 128 x 128
            self.down2 = Downsample(64, 128)   # 128 x 64 x 64
            self.conv1 = nn.Conv2d(128, 256, kernel_size=3)
            self.bn1 = nn.BatchNorm2d(256)
            self.conv2 = nn.Conv2d(256, 1, kernel_size=3)

        def forward(self, anno, img):
            pair = torch.cat([anno, img], dim=1)    # batch x 6 x H x W
            feat = self.down1(pair, is_bn=False)    # no normalization on the first block
            feat = self.down2(feat)
            feat = self.bn1(F.leaky_relu(self.conv1(feat)))
            # F.dropout2d keeps its default training=True, so it is always active.
            feat = F.dropout2d(feat)
            # Sigmoid scores in (0, 1); batch x 1 x 60 x 60 for 256x256 inputs.
            return torch.sigmoid(self.conv2(feat))
    
    
    # Pick the compute device and report which one is used.
    if torch.cuda.is_available():
        device = 'cuda'
        print('using cuda:', torch.cuda.get_device_name(0))
    else:
        device = 'cpu'
        print(device)

    Gen = Generator().to(device)
    Dis = Discriminator().to(device)

    # Both networks use Adam with identical hyper-parameters.
    g_optimizer = torch.optim.Adam(Gen.parameters(), lr=1e-3, betas=(0.5, 0.999))
    d_optimizer = torch.optim.Adam(Dis.parameters(), lr=1e-3, betas=(0.5, 0.999))

    # Losses
    # Conditional-GAN adversarial loss: binary cross-entropy on the
    # discriminator's sigmoid outputs.
    loss_fn = torch.nn.BCELoss()
    # The L1 term is computed inline in the training loop as the mean absolute
    # difference between the generated and the real image.
    # Plotting helper
    def generator_images(model, test_anno, test_real):
        """Show input label map, ground-truth photo and generator output side by side.

        Only the first sample of each batch is displayed.

        Args:
            model: generator; called as ``model(test_anno)``.
            test_anno: batch of label maps, 4-D (N, C, H, W), normalized to
                [-1, 1] by the module-level ``transform``.
            test_real: batch of the matching real photos, same layout and range.
        """
        def to_display(batch):
            # NCHW -> NHWC numpy, then undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1].
            # imshow clips float RGB data outside [0, 1], so without this step
            # everything below zero is rendered black.
            return (batch.permute(0, 2, 3, 1).detach().cpu().numpy() + 1) / 2

        prediction = to_display(model(test_anno))
        test_anno = to_display(test_anno)
        test_real = to_display(test_real)

        plt.figure(figsize=(10, 10))
        display_list = [test_anno[0], test_real[0], prediction[0]]
        title = ['input', 'ground truth', 'output']
        for i, (caption, picture) in enumerate(zip(title, display_list)):
            plt.subplot(1, 3, i+1)
            plt.title(caption)
            plt.imshow(picture)
            plt.axis('off')
        plt.show()
    
    # Load the "extended" split as the test set.
    # Sorted so that photo i and label map i stay paired; glob makes no
    # ordering guarantee and CMP_dataset pairs the lists by index.
    test_imgs_path = sorted(glob.glob('extended/*.jpg'))
    test_annos_path = sorted(glob.glob('extended/*.png'))

    test_dataset = CMP_dataset(test_imgs_path, test_annos_path)
    test_daloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batchsize
    )

    # Fixed batch used to visualise progress after every epoch.
    # It is taken from the test split: the test loader was previously built
    # but never used, and the preview batch came from the training loader.
    # If no test files are found, fall back to the training loader so the
    # script still runs.
    preview_loader = test_daloader if len(test_dataset) > 0 else dataloader
    annos_batch, images_batch = next(iter(preview_loader))

    plt.figure(figsize=(6, 10))
    for i, (anno, img) in enumerate(zip(annos_batch[:3], images_batch[:3])):
        # CHW -> HWC for matplotlib, then undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1].
        anno = (anno.permute(1, 2, 0).numpy()+1)/2
        img = (img.permute(1, 2, 0).numpy()+1)/2
        plt.subplot(3, 2, i*2+1)
        plt.title('input_img')
        plt.imshow(anno)

        plt.subplot(3, 2, i*2+2)
        plt.title('output_img')
        plt.imshow(img)
    plt.show()
    
    # Preview batch lives on the training device so it can be fed to Gen directly.
    annos_batch, images_batch = annos_batch.to(device), images_batch.to(device)
    LAMBDA = 7  # weight of the L1 reconstruction loss

    D_loss = []  # mean discriminator loss per epoch
    G_loss = []  # mean generator loss per epoch
    for epoch in range(300):
        D_epoch_loss = 0
        G_epoch_loss = 0
        count = len(dataloader)  # number of batches per epoch
        for step, (annos, imgs) in enumerate(dataloader):
            imgs = imgs.to(device)
            annos = annos.to(device)

            # ---- Discriminator update ----
            d_optimizer.zero_grad()
            # Real pair (label map, real photo) should be scored as 1.
            disc_real_output = Dis(annos, imgs)
            d_real_loss = loss_fn(disc_real_output,
                                  torch.ones_like(disc_real_output))
            d_real_loss.backward()

            # Fake pair (label map, generated photo) should be scored as 0.
            # detach() keeps this backward pass out of the generator.
            gen_output = Gen(annos)
            dis_gen_output = Dis(annos, gen_output.detach())
            d_fake_loss = loss_fn(dis_gen_output,
                                  torch.zeros_like(dis_gen_output))
            d_fake_loss.backward()

            disc_loss = d_real_loss + d_fake_loss

            d_optimizer.step()

            # ---- Generator update ----
            # Clear the generator gradients first: without this call they
            # accumulate over every step of the whole run and each update
            # applies the sum of all past gradients.
            g_optimizer.zero_grad()
            disc_gen_out = Dis(annos, gen_output)
            # The generator wants its fake pair to be scored as real (1).
            gen_loss_crossentropyloss = loss_fn(disc_gen_out,
                                                torch.ones_like(disc_gen_out))
            gen_l1_loss = torch.mean(torch.abs(gen_output - imgs))
            gen_loss = LAMBDA * gen_l1_loss + gen_loss_crossentropyloss
            # This backward also writes gradients into Dis; they are discarded
            # by d_optimizer.zero_grad() at the start of the next step.
            gen_loss.backward()
            g_optimizer.step()

            D_epoch_loss += disc_loss.item()
            G_epoch_loss += gen_loss.item()

        D_epoch_loss /= count
        G_epoch_loss /= count
        D_loss.append(D_epoch_loss)
        G_loss.append(G_epoch_loss)
        print('Epoch', epoch)
        # No graph is needed for the preview forward pass.
        with torch.no_grad():
            generator_images(Gen, annos_batch, images_batch)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249
    • 250
    • 251
    • 252
    • 253
    • 254
    • 255
    • 256
    • 257
    • 258
    • 259
    • 260
    • 261
    • 262
    • 263
    • 264
    • 265
    • 266
    • 267
    • 268
    • 269
    • 270
    • 271
    • 272
    • 273
    • 274
    • 275
    • 276
    • 277

    给动漫素描自动上色的(AI上色)移步我的kaggle
    https://www.kaggle.com/code/jiyuanhai/pix2pix-test-pytorch

    CycleGAN

    这个厉害👍,我愿称之为最强,克服了pix2pix需要数据集一一对应的缺点。
    论文地址:https://arxiv.org/pdf/1703.10593.pdf
    【推荐同济子豪兄】或者看论文详解:https://www.bilibili.com/video/BV1Ya411a78P?spm_id_from=333.999.0.0&vd_source=66d85dad339b02807124d27ef76332c9
    B站也有很多讲的不错的视频。
    创新性地提出了循环一致性损失(cycle consistency loss),具体技术细节有些复杂,这里不多赘述。

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils import data
    import torchvision
    from torchvision import transforms
    
    import numpy as np
    import matplotlib.pyplot as plt
    import os
    import glob
    from PIL import Image
    import itertools
    
    # Domain A training images (apples): every *.jpg under data\trainA.
    apples_path = glob.glob(r'E:\深度之眼深度学习\GAN生成对抗网络实战(PyTorch版)\00、代码+课件+数据\cyclegan-pytorch参考代码—日月光华\data\trainA\*.jpg')
    
    # Optional preview of the first four apple images (disabled)
    # plt.figure(figsize=(8, 8))
    # for i, imh_path in enumerate(apples_path[:4]):
    #     img = Image.open(imh_path)
    #     np_image = np.array(img)
    #     plt.subplot(2, 2, i+1)
    #     plt.imshow(np_image)
    #     plt.title(str(np_image.shape))
    # plt.show()
    
    # Domain B training images (oranges): every *.jpg under data\trainB.
    oranges_path = glob.glob(r'E:\深度之眼深度学习\GAN生成对抗网络实战(PyTorch版)\00、代码+课件+数据\cyclegan-pytorch参考代码—日月光华\data\trainB\*.jpg')
    
    # Optional preview of the first four orange images (disabled)
    # plt.figure(figsize=(8, 8))
    # for i, imh_path in enumerate(oranges_path[:4]):
    #     img = Image.open(imh_path)
    #     np_image = np.array(img)
    #     plt.subplot(2, 2, i+1)
    #     plt.imshow(np_image)
    #     plt.title(str(np_image.shape))
    # plt.show()
    # Images used for the per-epoch preview.
    # NOTE(review): this globs trainA, the same folder as apples_path, so the
    # "test" images are training images - presumably a held-out folder was
    # intended; confirm against the dataset layout.
    apples_test_path = glob.glob(r'E:\深度之眼深度学习\GAN生成对抗网络实战(PyTorch版)\00、代码+课件+数据\cyclegan-pytorch参考代码—日月光华\data\trainA\*.jpg')
    
    # According to the original author the dataset is already 256x256, so no
    # resize/crop step is included; Normalize(0.5, 0.5) applies (x - 0.5) / 0.5.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5)
    ])
    
    class AO_Dataset(data.Dataset):
        """Single-domain image dataset (used for both apples and oranges).

        Each item is one image loaded from ``img_path[index]``, converted to
        RGB and passed through the module-level ``transform``.
        """

        def __init__(self, img_path):
            """Store the list of image file paths."""
            self.img_path = img_path

        def __getitem__(self, index):
            imgpath = self.img_path[index]
            # Force three channels: a grayscale or RGBA file would otherwise
            # yield a tensor that cannot be batched with RGB images and does
            # not match the 3-channel input of the networks.
            # The context manager closes the file handle once the pixel data
            # has been copied by convert().
            with Image.open(imgpath) as raw:
                pil_img = raw.convert('RGB')
            return transform(pil_img)

        def __len__(self):
            return len(self.img_path)
    
    
    BATHSIZE = 2
    NUMWORKERS = 10  # currently unused: num_workers is not passed to the loaders

    # One dataset per image list.
    apple_dataset = AO_Dataset(apples_path)
    orange_dataset = AO_Dataset(oranges_path)
    apple_test_dataset = AO_Dataset(apples_test_path)

    # All three loaders share the same settings: small shuffled batches.
    _loader_options = dict(batch_size=BATHSIZE, shuffle=True)

    apple_dataloader = data.DataLoader(apple_dataset, **_loader_options)
    orange_dataloader = data.DataLoader(orange_dataset, **_loader_options)
    apple_dl_test = data.DataLoader(apple_test_dataset, **_loader_options)
    # Model definition
    # Downsampling block
    class Downsample(nn.Module):
        """3x3 stride-2 convolution + LeakyReLU, optionally followed by InstanceNorm.

        With kernel 3, stride 2 and padding 1, an even spatial size is halved.
        The ``bn`` attribute and the ``is_bn`` flag keep their names although
        the layer is an ``InstanceNorm2d``.
        """

        def __init__(self, in_channels, out_channels):
            super().__init__()
            conv = nn.Conv2d(in_channels, out_channels,
                             kernel_size=3, stride=2, padding=1)
            self.conv_relu = nn.Sequential(conv, nn.LeakyReLU(inplace=True))
            self.bn = nn.InstanceNorm2d(out_channels)

        def forward(self, x, is_bn=True):
            """Apply conv + activation; skip the normalization when ``is_bn`` is False."""
            out = self.conv_relu(x)
            return self.bn(out) if is_bn else out
    
    # Upsampling block
    class Upsample(nn.Module):
        """3x3 stride-2 transposed convolution + LeakyReLU + InstanceNorm.

        With kernel 3, stride 2, padding 1 and output_padding 1 the spatial
        size is doubled. Dropout is applied on request.
        """

        def __init__(self, in_channels, out_channels):
            super().__init__()
            deconv = nn.ConvTranspose2d(in_channels, out_channels,
                                        kernel_size=3, stride=2, padding=1,
                                        output_padding=1)
            self.upconv_relu = nn.Sequential(deconv, nn.LeakyReLU(inplace=True))
            self.bn = nn.InstanceNorm2d(out_channels)

        def forward(self, x, is_drop=False):
            """Upsample ``x``; apply channel dropout afterwards when ``is_drop`` is True."""
            out = self.bn(self.upconv_relu(x))
            if not is_drop:
                return out
            # F.dropout2d is called with its default training=True, so the
            # dropout is active regardless of self.training.
            return F.dropout2d(out)
    
    # Generator: 6 downsampling blocks, 5 upsampling blocks, 1 output layer
    class Generator(nn.Module):
        """U-Net style encoder/decoder with skip connections.

        Each decoder stage output is concatenated (channel-wise) with the
        encoder activation of the same resolution, which is why the decoder
        input widths are doubled. The shape comments assume a 3x256x256 input.
        """

        def __init__(self):
            super().__init__()
            # Encoder (C, H, W after each stage)
            self.down1 = Downsample(3, 64)      # 64, 128, 128
            self.down2 = Downsample(64, 128)    # 128, 64, 64
            self.down3 = Downsample(128, 256)   # 256, 32, 32
            self.down4 = Downsample(256, 512)   # 512, 16, 16
            self.down5 = Downsample(512, 512)   # 512, 8, 8
            self.down6 = Downsample(512, 512)   # 512, 4, 4

            # Decoder; inputs from up2 onwards include the concatenated skip.
            self.up1 = Upsample(512, 512)       # 512, 8, 8
            self.up2 = Upsample(1024, 512)      # 512, 16, 16
            self.up3 = Upsample(1024, 256)      # 256, 32, 32
            self.up4 = Upsample(512, 128)       # 128, 64, 64
            self.up5 = Upsample(256, 64)        # 64, 128, 128

            # Final upsampling back to 3 channels at the input resolution.
            self.last = nn.ConvTranspose2d(128, 3,
                                           kernel_size=3,
                                           stride=2,
                                           padding=1,
                                           output_padding=1)

        def forward(self, x):
            # Encoder pass, keeping every activation for the skip connections.
            skips = []
            feat = x
            encoders = (self.down1, self.down2, self.down3,
                        self.down4, self.down5, self.down6)
            for encoder in encoders:
                feat = encoder(feat)
                skips.append(feat)
            # The bottleneck output feeds the decoder; it is not a skip itself.
            skips.pop()

            # Decoder pass: the first four stages use dropout, the last does not.
            decoders = ((self.up1, True), (self.up2, True), (self.up3, True),
                        (self.up4, True), (self.up5, False))
            for decoder, use_dropout in decoders:
                feat = decoder(feat, is_drop=use_dropout)
                feat = torch.cat([feat, skips.pop()], dim=1)

            # tanh keeps the output in [-1, 1].
            return torch.tanh(self.last(feat))
    
    # Discriminator: scores a single image
    class Discriminator(nn.Module):
        """Unconditional discriminator producing a map of per-location scores.

        Two downsampling blocks followed by an unpadded 3x3 convolution and a
        sigmoid. The shape comments assume 256x256 inputs.
        """

        def __init__(self):
            super().__init__()
            self.down1 = Downsample(3, 64)     # 128 x 128
            self.down2 = Downsample(64, 128)   # 64 x 64
            self.last = nn.Conv2d(128, 1, kernel_size=3)

        def forward(self, img):
            feat = self.down2(self.down1(img))
            # The unpadded 3x3 conv shrinks 64 -> 62, so the score map is
            # batch x 1 x 62 x 62 for 256x256 inputs.
            return torch.sigmoid(self.last(feat))
    
    
    # Pick the compute device and report which one is used.
    if torch.cuda.is_available():
        device = 'cuda'
        print('using cuda:', torch.cuda.get_device_name(0))
    else:
        device = 'cpu'
        print(device)

    # Two generators, one per translation direction (A->B and B->A).
    gen_AB = Generator().to(device)
    gen_BA = Generator().to(device)

    # Two discriminators, one per image domain.
    dis_A = Discriminator().to(device)
    dis_B = Discriminator().to(device)

    # Loss functions
    # 1. adversarial loss (BCE on the discriminators' sigmoid outputs)
    # 2. cycle-consistency loss (L1)
    # 3. identity loss (L1)
    BECLoss = torch.nn.BCELoss()
    L1_loss = torch.nn.L1Loss()

    # A single optimizer updates both generators together.
    gen_optimizer = torch.optim.Adam(
        itertools.chain(gen_AB.parameters(), gen_BA.parameters()),
        lr=2e-4,
        betas=(0.5, 0.999),
    )

    # Each discriminator gets its own optimizer with the same settings.
    dis_A_optimizer = torch.optim.Adam(dis_A.parameters(), lr=2e-4, betas=(0.5, 0.999))
    dis_B_optimizer = torch.optim.Adam(dis_B.parameters(), lr=2e-4, betas=(0.5, 0.999))
    
    
    def generate_images(model, test_input):
        """Plot the first input image next to the model's translation of it.

        Args:
            model: generator; called as ``model(test_input)``.
            test_input: 4-D image batch (N, C, H, W); values are mapped back
                for display with ``x * 0.5 + 0.5``.
        """
        # NCHW -> NHWC numpy arrays for matplotlib.
        predicted = model(test_input).permute(0, 2, 3, 1).detach().cpu().numpy()
        source = test_input.permute(0, 2, 3, 1).cpu().numpy()

        plt.figure(figsize=(10, 6))
        panels = zip(('Input Image', 'Genrated Image'), (source[0], predicted[0]))
        for position, (caption, picture) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.title(caption)
            # Undo Normalize(0.5, 0.5) so the values fall in [0, 1].
            plt.imshow(picture * 0.5 + 0.5)
            plt.axis('off')
        plt.show()
    
    
    # Fixed preview sample: the first image of one batch from the test loader,
    # with a batch dimension added back so it can be fed to a generator.
    test_batch = next(iter(apple_dl_test))
    test_input = torch.unsqueeze(test_batch[0], 0).to(device)
    
    D_loss = []  # mean discriminator loss per epoch
    G_loss = []  # mean generator loss per epoch
    epoches = 50
    for epoch in range(epoches):
        d_epoch_loss = 0
        g_epoch_loss = 0
        # zip stops at the shorter of the two loaders, so an epoch covers
        # min(len(apple_dataloader), len(orange_dataloader)) batches.
        for step, (real_A, real_B) in enumerate(zip(apple_dataloader, orange_dataloader)):
            real_A = real_A.to(device)
            real_B = real_B.to(device)
    
            # ---- Generator update (both generators share one optimizer) ----
            gen_optimizer.zero_grad()
    
            # Identity loss: feeding a generator an image that is already in
            # its target domain should return that image unchanged.
            same_B = gen_AB(real_B)
            identity_B_loss = L1_loss(same_B, real_B)
    
            same_A = gen_BA(real_A)
            identity_A_loss = L1_loss(same_A, real_A)
    
            # Adversarial (GAN) loss: each generator wants the matching
            # discriminator to score its fake as real (target 1).
            fake_B = gen_AB(real_A)
            D_pre_fake_B = dis_B(fake_B)
            gen_loss_AB = BECLoss(D_pre_fake_B,
                    torch.ones_like(D_pre_fake_B, device=device))
    
            fake_A = gen_BA(real_B)
            D_pre_fake_A = dis_A(fake_A)
            gen_loss_BA = BECLoss(D_pre_fake_A,
                    torch.ones_like(D_pre_fake_A, device=device))
    
            # Cycle-consistency loss: A -> B -> A and B -> A -> B should
            # reproduce the original image.
            recovered_A = gen_BA(fake_B)
            cycle_loss_ABA = L1_loss(recovered_A, real_A)
    
            recovered_B = gen_AB(fake_A)
            cycle_loss_BAB = L1_loss(recovered_B, real_B)
    
            # Total generator loss: unweighted sum of all six terms.
            g_loss = identity_A_loss +identity_B_loss +gen_loss_AB +\
                     gen_loss_BA+cycle_loss_ABA+cycle_loss_BAB
    
            # This backward also writes gradients into dis_A / dis_B; they are
            # cleared by the zero_grad() calls before each discriminator update.
            g_loss.backward()
            gen_optimizer.step()
    
            # ---- Discriminator A update ----
            # Real A images are labelled 1; fake A images are labelled 0 and
            # detached so no gradient reaches the generators.
            dis_A_optimizer.zero_grad()
            dis_A_real_output = dis_A(real_A)
            dis_A_real_loss = BECLoss(dis_A_real_output,
                                      torch.ones_like(dis_A_real_output, device=device))
            dis_A_fake_output = dis_A(fake_A.detach())
            dis_A_fake_loss = BECLoss(dis_A_fake_output,
                                        torch.zeros_like(dis_A_fake_output, device=device))
    
            dis_A_loss = dis_A_real_loss + dis_A_fake_loss
            dis_A_loss.backward()
            dis_A_optimizer.step()
    
            # ---- Discriminator B update (same scheme as A) ----
            dis_B_optimizer.zero_grad()
            dis_B_real_output = dis_B(real_B)
            dis_B_real_loss = BECLoss(dis_B_real_output,
                                      torch.ones_like(dis_B_real_output, device=device))
    
            dis_B_fake_output = dis_B(fake_B.detach())
            dis_B_fake_loss = BECLoss(dis_B_fake_output,
                                      torch.zeros_like(dis_B_fake_output, device=device))
            dis_B_loss = dis_B_fake_loss + dis_B_real_loss
            dis_B_loss.backward()
            dis_B_optimizer.step()
    
            # Accumulate the scalar losses for the epoch averages.
            with torch.no_grad():
                g_epoch_loss += g_loss.item()
                d_epoch_loss += (dis_A_loss + dis_B_loss).item()
        with torch.no_grad():
            # step is the index of the last batch, so step+1 is the batch count.
            g_epoch_loss /= (step+1)
            d_epoch_loss /= (step+1)
            D_loss.append(d_epoch_loss)
            G_loss.append(g_epoch_loss)
            print('Epoch:', epoch+1)
            print('g_epoch_loss:', g_epoch_loss)
            print('d_epoch_loss:', d_epoch_loss)
            generate_images(gen_AB, test_input)  # test_input is apple
    
    # Save the four networks as whole module objects (not just state_dicts),
    # with the new zipfile serialization format turned off.
    torch.save(gen_AB, 'Gen_AB.pth', _use_new_zipfile_serialization=False)
    torch.save(gen_BA, 'Gen_BA.pth', _use_new_zipfile_serialization=False)
    torch.save(dis_B, 'Dis_B.pth', _use_new_zipfile_serialization=False)
    torch.save(dis_A, 'Dis_A.pth', _use_new_zipfile_serialization=False)
    
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249
    • 250
    • 251
    • 252
    • 253
    • 254
    • 255
    • 256
    • 257
    • 258
    • 259
    • 260
    • 261
    • 262
    • 263
    • 264
    • 265
    • 266
    • 267
    • 268
    • 269
    • 270
    • 271
    • 272
    • 273
    • 274
    • 275
    • 276
    • 277
    • 278
    • 279
    • 280
    • 281
    • 282
    • 283
    • 284
    • 285
    • 286
    • 287
    • 288
    • 289
    • 290
    • 291
    • 292
    • 293
    • 294
    • 295
    • 296
    • 297
    • 298
    • 299
    • 300
    • 301
    • 302
    • 303
    • 304
    • 305
    • 306
    • 307
    • 308
    • 309
    • 310
    • 311
    • 312
    • 313
    • 314
    • 315
    • 316
    • 317
    • 318
    • 319
    • 320
    • 321
    • 322
    • 323
    • 324
    • 325
    • 326
    • 327
    • 328
    • 329
  • 相关阅读:
    Springboot:统一异常处理
    Rockwell EDI 855 采购订单确认报文详解
    如何用mindspore高阶API实现互学习
    LeetCode 面试题 03.01. 三合一
    【21天python打卡】第16天 python经典案例(2)
    docker基础认知(镜像+容器+仓库+客户端与服务器)
    wx.getPrivacySetting 小程序隐私保护指引的使用(复制粘贴即用)
    LeetCode-779. 第K个语法符号【递归,绝对好理解】
    大数据Hadoop入门教程 | (一)概论
    Linux磁盘分区
  • 原文地址:https://blog.csdn.net/qq_45882032/article/details/125620988