• Pytorch Advanced(一) Generative Adversarial Networks


    生成对抗神经网络GAN,发挥神经网络的想象力,可以说是十分厉害了

    参考

    1、AI作家
    2、将模糊图变清晰(去雨,去雾,去抖动,去马赛克等),这需要AI具有“想象力”,能脑补情节;
    3、进行数据增强,根据已有数据生成更多新数据供以feed,可以减缓模型过拟合现象。

    那到底是怎么实现的呢?


    GAN中有两大组成部分G和D

    G是generator,生成器: 负责凭空捏造数据出来

    D是discriminator,判别器: 负责判断数据是不是真数据

    示例图如下:

    给一个随机噪声z,通过G生成一张假图,然后用D去分辨是真图还是假图。假设G生成了一张图,在D那里的得分很高,那么G就很成功的骗过了D,如果D很轻松的分辨出了假图,那么G的效果不好,那么就需要调整参数了。


    G和D是两个单独的网络,那么他们的参数都是训练好的吗?并不是,两个网络的参数是需要在博弈的过程中分别优化的。

    下面就是一个训练的过程:

    GAN在一轮反向传播中分为两步,先训练D在训练G。

    训练D时,上一轮G产生的图片,和真实图片一起作为x进行输入,假图为0,真图标签为1,通过x生成一个score,通过score和标签y计算损失,就可以进行反向传播了。

    训练G时,G和D是一个整体,取名为D_on_G。输入随机噪声,G产生一个假图,D去分辨,score = 1就是需要我们需要优化的目标,意思就是我们要让生成的图片变成真的。这里的D是不需要参与梯度计算的,我们通过反向传播来优化G,让他生成更加真实的图片。这就好比:如果你参加考试,你别指望能改变老师的评分标准


    GAN无监督学习,(cGAN是有监督的),以后会学习的。怎么理解无监督学习呢?这里给的真图是没有经过人工标注的,只知道这是真的,D是不知道这是什么的,只需要分辨真假。G也不知道生成了什么,只需要学真图去骗D。


    具体如何实施呢?

    1. import os
    2. import torch
    3. import torchvision
    4. import torch.nn as nn
    5. from torchvision import transforms
    6. from torchvision.utils import save_image
    7. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    1. latent_size = 64
    2. hidden_size = 256
    3. image_size = 784
    4. num_epochs = 200
    5. batch_size = 100
    6. sample_dir = 'samples'

    注意这里有个归一化的过程,MNIST是单通道,但是如果mean=(0.5,0.5,0.5)会报错,因为是对3通道操作 。

    1. if not os.path.exists(sample_dir):
    2. os.makedirs(sample_dir)
    3. transform = transforms.Compose([
    4. transforms.ToTensor(),
    5. transforms.Normalize(mean=(0.5,), # 3 for RGB channels
    6. std=(0.5,))])
    7. # MNIST dataset
    8. mnist = torchvision.datasets.MNIST(root='./data/',train=True,transform=transform,download=True)
    9. # Data loader
    10. data_loader = torch.utils.data.DataLoader(dataset=mnist,batch_size=batch_size, shuffle=True)

    定义生成器和判别器:

    生成器:可以看到输入的维度为64,是一组噪声图像,通过生成器将特征扩大到了MNIST图像大小784。

    判别器:输入维度为图像大小,最后输出特征个数为1,采用sigmoid激活(不用softmax的)

    1. # Discriminator
    2. D = nn.Sequential(
    3. nn.Linear(image_size, hidden_size),
    4. nn.LeakyReLU(0.2),
    5. nn.Linear(hidden_size, hidden_size),
    6. nn.LeakyReLU(0.2),
    7. nn.Linear(hidden_size, 1),
    8. nn.Sigmoid())
    9. # Generator
    10. G = nn.Sequential(
    11. nn.Linear(latent_size, hidden_size),
    12. nn.ReLU(),
    13. nn.Linear(hidden_size, hidden_size),
    14. nn.ReLU(),
    15. nn.Linear(hidden_size, image_size),
    16. nn.Tanh())
    1. # Device setting
    2. D = D.to(device)
    3. G = G.to(device)
    4. # Binary cross entropy loss and optimizer
    5. criterion = nn.BCELoss()
    6. d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
    7. g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
    8. def denorm(x):
    9. out = (x + 1) / 2
    10. return out.clamp(0, 1)
    11. def reset_grad():
    12. d_optimizer.zero_grad()
    13. g_optimizer.zero_grad()

     重点看训练部分,我们到底是如何来训练GAN的。

    判别器部分:判别器的损失值分为两部分,(一)将mini_batch定义为正样本,告诉他我是正品,所以设置标签为1。优化判别器判断正品的能力;(二)生成一幅赝品,再给判别器判别,这时候赝品的标签为0,优化判断赝品的能力。所以总损失为这两部分之和,计算梯度,优化判别器参数。

    G_on_D:输入一个噪声,让生成器生成一幅图像,然后让D去判别,计算和正品之间的距离,即损失。反向传播,优化G的参数。

    1. # Start training
    2. total_step = len(data_loader)
    3. for epoch in range(num_epochs):
    4. for i, (images, _) in enumerate(data_loader):
    5. images = images.reshape(batch_size, -1).to(device)
    6. # Create the labels which are later used as input for the BCE loss
    7. real_labels = torch.ones(batch_size, 1).to(device)
    8. fake_labels = torch.zeros(batch_size, 1).to(device)
    9. # ================================================================== #
    10. # Train the discriminator #
    11. # ================================================================== #
    12. # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
    13. # Second term of the loss is always zero since real_labels == 1
    14. outputs = D(images)
    15. d_loss_real = criterion(outputs, real_labels)
    16. real_score = outputs
    17. # Compute BCELoss using fake images
    18. # First term of the loss is always zero since fake_labels == 0
    19. z = torch.randn(batch_size, latent_size).to(device)
    20. fake_images = G(z)
    21. outputs = D(fake_images)
    22. d_loss_fake = criterion(outputs, fake_labels)
    23. fake_score = outputs
    24. # Backprop and optimize
    25. d_loss = d_loss_real + d_loss_fake
    26. reset_grad()
    27. d_loss.backward()
    28. d_optimizer.step()
    29. # ================================================================== #
    30. # Train the generator #
    31. # ================================================================== #
    32. # Compute loss with fake images
    33. z = torch.randn(batch_size, latent_size).to(device)
    34. fake_images = G(z)
    35. outputs = D(fake_images)
    36. # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
    37. # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
    38. g_loss = criterion(outputs, real_labels)
    39. # Backprop and optimize
    40. reset_grad()
    41. g_loss.backward()
    42. g_optimizer.step()
    43. if (i+1) % 200 == 0:
    44. print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
    45. .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
    46. real_score.mean().item(), fake_score.mean().item()))
    47. # Save real images
    48. if (epoch+1) == 1:
    49. images = images.reshape(images.size(0), 1, 28, 28)
    50. save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
    51. # Save sampled images
    52. fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    53. save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

    训练完了怎么用?

    只要用我们的生成器就可以随意生成了。

    1. import matplotlib.pyplot as plt
    2. z = torch.randn(1,latent_size).to(device)
    3. output = G(z)
    4. plt.imshow(output.cpu().data.numpy().reshape(28,28),cmap='gray')
    5. plt.show()

     下面就是随机生成的图像了!

      

  • 相关阅读:
    Nginx实现tcp代理并支持TLS加密实验
    【code-server】Code-Server 安装部署
    新款吉利星越L正式上市,媒介盒子多家媒体报道
    爬虫软件是什么意思
    怎样生成分布式的流水ID
    iPhone15线下购买,苹果零售店前门店排长队
    SQL 注入绕过(四)
    广州华锐互动:VR模拟高楼层建筑应急逃生,提供身临其境的虚拟体验
    QT延时/等待
    select\poll\epoll的区别
  • 原文地址:https://blog.csdn.net/qq_41828351/article/details/90813916