• 第G7周:Semi-Supervised GAN 理论与实战


     🍨 本文为🔗365天深度学习训练营 中的学习记录博客

      🍦 参考文章:365天深度学习训练营-第G7周:Semi-Supervised GAN 理论与实战(训练营内部成员可读)

      🍖 原作者:K同学啊|接辅导、项目定制

     🏡 运行环境:
    电脑系统:Windows 10
    语言环境:Python 3.10
    开发工具(IDE):PyCharm 2022.1.1
    深度学习环境:PyTorch


    目录

    一、理论知识讲解

    二、代码实现

    1、配置代码 

     2、初始化权重

    3、定义算法模型

    4、配置模型

     5、训练模型


    一、理论知识讲解

    该算法将生成式对抗网络(GAN)拓展到半监督学习,做法是强制判别器D输出类别标签。我们在一个数据集上训练一个生成器G和一个判别器D,其中真实输入样本属于N个类别之一。
    训练时,判别器D需要预测输入属于N+1个类别中的哪一个,其中多出来的第N+1类对应生成器G生成的样本;因此这里的判别器D同时也充当了分类器C的角色。
    这种方法可以训练出效果更好的判别器D,并且能比普通的GAN生成质量更高的样本。Semi-Supervised GAN有如下优点:
    (1)作者对GANs做了一个新的扩展,使其能够同时学习一个生成模型和一个分类器。我们把这个扩展叫做半监督GAN或SGAN。
    (2)论文实验结果表明,在有限的数据集上,SGAN的分类性能优于没有生成部分的基准分类器。
    (3)论文实验结果表明,SGAN可以显著提升生成样本的质量,并缩短生成器的训练时间。

    二、代码实现

    1、配置代码 
    1. import argparse
    2. import os
    3. import numpy as np
    4. import math
    5. import torchvision.transforms as transforms
    6. from torchvision.utils import save_image
    7. from torch.utils.data import DataLoader
    8. from torchvision import datasets
    9. from torch.autograd import Variable
    10. import torch.nn as nn
    11. import torch.nn.functional as F
    12. import torch
    # Output directory for the sample grids written during training.
    os.makedirs("images", exist_ok=True)

    # Hyper-parameters. They are parsed from an explicit empty argument list so
    # the script also runs inside a notebook; edit the defaults below (or pass
    # a list to parse_args) to change them.
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=2, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    # b2 is Adam's second beta, i.e. the decay of the *second*-order moment.
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
    # NOTE(review): opt.n_cpu does not appear to be used anywhere in this script.
    parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--num_classes", type=int, default=10, help="number of classes for dataset")
    parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
    opt = parser.parse_args(args=[])
    print(opt)

    # Use the GPU whenever PyTorch can see one (already a bool).
    cuda = torch.cuda.is_available()
    Namespace(n_epochs=2, batch_size=64, lr=0.0002, b1=0.5, b2=0.999, n_cpu=2, latent_dim=100, num_classes=10, img_size=32, channels=1, sample_interval=400)
     2、初始化权重
    def weights_init_normal(m):
        """Re-initialise one layer in place (DCGAN-style); meant for ``Module.apply``.

        * class name contains "Conv": weight drawn from N(0, 0.02); bias untouched.
        * class name contains "BatchNorm": weight drawn from N(1, 0.02), bias set to 0.
        * any other module is left as it is.
        """
        layer_kind = type(m).__name__
        if "Conv" in layer_kind:
            torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
            return
        if "BatchNorm" in layer_kind:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)
    3、定义算法模型
    class Generator(nn.Module):
        """Unconditional SGAN generator: noise vector -> image with values in [-1, 1].

        A linear layer lifts the noise (assumed shape (B, opt.latent_dim)) to a
        128-channel map of side ``opt.img_size // 4``. Two x2 ``nn.Upsample``
        stages bring it back to ``opt.img_size``. A final conv + tanh then
        projects it to ``opt.channels`` channels.
        """

        def __init__(self):
            super().__init__()
            # NOTE(review): label_emb is created but never used by forward();
            # it is kept so parameters/state_dict stay as before.
            self.label_emb = nn.Embedding(opt.num_classes, opt.latent_dim)

            # Spatial side of the feature map before the two x2 upsamplings.
            self.init_size = opt.img_size // 4
            side = self.init_size
            self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * side * side))

            # BatchNorm2d's 2nd positional parameter is eps, so the original
            # ``BatchNorm2d(c, 0.8)`` means eps=0.8; it is spelled out here.
            # NOTE(review): 0.8 may have been intended as momentum -- confirm
            # before changing, since that alters training.
            self.conv_blocks = nn.Sequential(
                nn.BatchNorm2d(128),
                nn.Upsample(scale_factor=2),
                nn.Conv2d(128, 128, 3, stride=1, padding=1),
                nn.BatchNorm2d(128, eps=0.8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Upsample(scale_factor=2),
                nn.Conv2d(128, 64, 3, stride=1, padding=1),
                nn.BatchNorm2d(64, eps=0.8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
                nn.Tanh(),
            )

        def forward(self, noise):
            hidden = self.l1(noise)
            side = self.init_size
            # Reshape the flat projection into a (B, 128, side, side) feature map.
            hidden = hidden.view(hidden.size(0), 128, side, side)
            return self.conv_blocks(hidden)
    class Discriminator(nn.Module):
        """SGAN discriminator: one conv trunk feeding two heads.

        ``forward(img)`` returns ``(validity, label)``:

        * ``validity`` -- shape (B, 1): sigmoid probability that ``img`` is real
          (pairs with ``nn.BCELoss``).
        * ``label`` -- shape (B, num_classes + 1): **unnormalised logits** over
          the ``num_classes`` real classes plus one extra "generated" class.
          They are left as raw scores because ``nn.CrossEntropyLoss`` applies
          log-softmax itself; argmax over logits equals argmax over
          probabilities. Use ``torch.softmax(label, dim=1)`` for probabilities.

        Args:
            channels: input image channels (default: ``opt.channels``).
            img_size: input image side length (default: ``opt.img_size``).
            num_classes: number of real classes (default: ``opt.num_classes``).
        """

        def __init__(self, channels=None, img_size=None, num_classes=None):
            super(Discriminator, self).__init__()
            # Fall back to the global hyper-parameters so Discriminator() is unchanged.
            channels = opt.channels if channels is None else channels
            img_size = opt.img_size if img_size is None else img_size
            num_classes = opt.num_classes if num_classes is None else num_classes

            def discriminator_block(in_filters, out_filters, bn=True):
                """Return the layers of one block: stride-2 conv, LeakyReLU, Dropout2d[, BatchNorm]."""
                block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
                if bn:
                    # 2nd positional arg of BatchNorm2d is eps (eps=0.8 here).
                    block.append(nn.BatchNorm2d(out_filters, 0.8))
                return block

            self.conv_blocks = nn.Sequential(
                *discriminator_block(channels, 16, bn=False),
                *discriminator_block(16, 32),
                *discriminator_block(32, 64),
                *discriminator_block(64, 128),
            )

            # Each kernel-3 / stride-2 / padding-1 conv maps a side n to ceil(n / 2).
            # Apply that once per block; plain img_size // 2**4 is only right when
            # img_size is a multiple of 16 (e.g. it breaks for 28x28 inputs).
            ds_size = img_size
            for _ in range(4):
                ds_size = (ds_size + 1) // 2
            n_features = 128 * ds_size ** 2

            # Output layers
            self.adv_layer = nn.Sequential(nn.Linear(n_features, 1), nn.Sigmoid())
            # No Softmax: the caller feeds these scores to CrossEntropyLoss, which
            # expects logits (softmax followed by CE squashes the gradients).
            self.aux_layer = nn.Sequential(nn.Linear(n_features, num_classes + 1))

        def forward(self, img):
            out = self.conv_blocks(img)
            out = out.view(out.shape[0], -1)
            validity = self.adv_layer(out)
            label = self.aux_layer(out)
            return validity, label
    4、配置模型
    # Loss for the real/fake head (sigmoid output) and for the class head.
    adversarial_loss = torch.nn.BCELoss()
    auxiliary_loss = torch.nn.CrossEntropyLoss()

    # Build the two networks (construction order kept as in the original).
    generator = Generator()
    discriminator = Discriminator()

    # Move networks and losses to the GPU before re-initialising the weights.
    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()
        auxiliary_loss.cuda()

    # Re-initialise every Conv*/BatchNorm* layer of both networks.
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # MNIST training split, resized to opt.img_size; ToTensor gives [0, 1] and
    # Normalize(0.5, 0.5) maps that to [-1, 1].
    os.makedirs("../../data/mnist", exist_ok=True)
    dataloader = DataLoader(
        datasets.MNIST(
            "../../data/mnist",
            train=True,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.Resize(opt.img_size),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5], [0.5]),
                ]
            ),
        ),
        batch_size=opt.batch_size,
        shuffle=True,
    )

    # One Adam optimiser per network, same hyper-parameters for both.
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    # Tensor constructors used by the training loop, on the chosen device.
    if cuda:
        FloatTensor, LongTensor = torch.cuda.FloatTensor, torch.cuda.LongTensor
    else:
        FloatTensor, LongTensor = torch.FloatTensor, torch.LongTensor
    Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
    Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw\train-images-idx3-ubyte.gz
    
    Extracting ../../data/mnist\MNIST\raw\train-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw
    
    Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
    Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw\train-labels-idx1-ubyte.gz
    
    Extracting ../../data/mnist\MNIST\raw\train-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw
    
    Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
    Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz
    
    Extracting ../../data/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw
    
    Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
    Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz
    
    Extracting ../../data/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw
     5、训练模型
    # ----------
    #  Training
    # ----------
    for epoch in range(opt.n_epochs):
        for i, (imgs, labels) in enumerate(dataloader):
            batch_size = imgs.shape[0]

            # Targets: 1/0 for the real/fake head, and class index opt.num_classes
            # (the extra "generated" class) for the class head on fake samples.
            valid = FloatTensor(batch_size, 1).fill_(1.0)
            fake = FloatTensor(batch_size, 1).fill_(0.0)
            fake_aux_gt = LongTensor(batch_size).fill_(opt.num_classes)

            # Move the batch to the working dtype/device.
            real_imgs = imgs.type(FloatTensor)
            labels = labels.type(LongTensor)

            # ---- Generator step: make the discriminator call the fakes real ----
            optimizer_G.zero_grad()
            z = FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim)))
            gen_imgs = generator(z)
            validity, _ = discriminator(gen_imgs)
            g_loss = adversarial_loss(validity, valid)
            g_loss.backward()
            optimizer_G.step()

            # ---- Discriminator step: real vs. generated, plus the class head ----
            optimizer_D.zero_grad()
            real_pred, real_aux = discriminator(real_imgs)
            d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2
            # detach(): this step must not push gradients into the generator.
            fake_pred, fake_aux = discriminator(gen_imgs.detach())
            d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, fake_aux_gt)) / 2
            d_loss = (d_real_loss + d_fake_loss) / 2

            # Class-head accuracy over the real and generated halves together.
            scores = np.concatenate([t.detach().cpu().numpy() for t in (real_aux, fake_aux)], axis=0)
            targets = np.concatenate([t.detach().cpu().numpy() for t in (labels, fake_aux_gt)], axis=0)
            d_acc = np.mean(scores.argmax(axis=1) == targets)

            d_loss.backward()
            optimizer_D.step()

            batches_done = epoch * len(dataloader) + i
            if batches_done % opt.sample_interval == 0:
                save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

        # One progress line per epoch, reporting the epoch's last batch.
        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
        )
    [Epoch 0/2] [Batch 937/938] [D loss: 1.358861, acc: 50%] [G loss: 0.671799]
    [Epoch 1/2] [Batch 937/938] [D loss: 1.343094, acc: 50%] [G loss: 0.681119]
  • 相关阅读:
    迅狐短视频矩阵管理系统核心功能
    Git 学习笔记 | Git 基本操作命令
    A_A02_003 ST-LINK驱动安装
    汉纳西点:100天成功打造大连行业最大单体店,创造一个商业传奇
    13年测试老鸟,性能测试内存泄露——案例分析(超细整理)
    [LeetCode]剑指 Offer 42. 连续子数组的最大和
    Python 数据可视化:Seaborn 库的使用
    一键重装win7系统详细教程
    unity 从UI上拖出3D物体,(2D转3D)
    基于共词分析的中国近代史实体关系图构建(毕业设计:图数据渲染)
  • 原文地址:https://blog.csdn.net/m0_62800398/article/details/134149536