• 深度学习(PyTorch)——生成对抗网络(GAN)


    一、GAN的基本概念

    GAN是由Ian Goodfellow于2014年首次提出,学习GAN的初衷,即生成不存在于真实世界的数据。类似于AI具有创造力和想象力。

    GAN有两大护法G和D:

    G是generator,生成器: 负责凭空捏造数据出来

    D是discriminator,判别器: 负责判断数据是不是真数据

    这样可以简单看作是两个网络的博弈过程。在原始的GAN论文里面,G和D都是两个多层感知机网络。GAN操作的数据不一定非是图像数据,在此用图像数据为例解释以下GAN:

    上图中,z是随机噪声(随机生成的一些数,也是GAN生成图像的源头)。D通过真图和假图的数据,进行一个二分类神经网络训练。G根据一串随机数就可以捏造出一个"假图像"出来,用这些假图去欺骗D,D负责辨别这是真图还是假图,会给出一个score。比如,G生成了一张图,在D这里评分很高,说明G生成能力是很成功的;若D给出的评分不高,可以有效区分真假图,则G的效果还不太好,需要调整参数。

    二、GAN的基本原理

    GAN的训练在同一轮梯度反转的过程中可以细分为2步:(1)先训练D;(2)再训练G。注意,不是等所有的D训练好了才开始训练G,因为D的训练也需要上一轮梯度反转中的G的输出值作为输入。

    当训练D的时候:上一轮G产生的图片和真实图片,直接拼接在一起作为x。然后按顺序摆放成0和1,假图对应0,真图对应1。然后就可以通过D,x输入生成一个score(从0到1之间的数),通过score和y组成的损失函数,就可以进行梯度反传了。(我在图片上举的例子是batch = 1,len(y)=2*batch,训练时通常可以取较大的batch)

    当训练G的时候:需要把G和D当作一个整体,这里取名叫做’D_on_G’。这个整体(简称DG系统)的输出仍然是score。输入一组随机向量z,就可以在G生成一张图,通过D对生成的这张图进行打分得到score,这就是DG系统的前向过程。score=1就是DG系统需要优化的目标,score和y=1之间的差异可以组成损失函数,然后可以采用反向传播梯度。注意,这里的D的参数是不可训练的。这样就能保证G的训练是符合D的打分标准的。这就好比:如果你参加考试,你别指望能改变老师的评分标准。

    GAN模型的目标函数如下:

    在这里,训练网络D使得最大概率地分对训练样本的标签(最大化log D(x)和log(1—D(G(z)))),训练网络G最小化log(1-D(G(z))),即最大化D的损失。而训练过程中固定一方,更新另一个网络的参数,交替迭代,使得对方的错误最大化,最终,G 能估测出样本数据的分布,也就是生成的样本更加的真实。

    三、案例、用GAN对输入的随机数生成手写字体的图片

    程序如下:

    1. import torch
    2. import torch.nn as nn
    3. import torch.optim as optim
    4. import torch.nn.functional as F
    5. import numpy as np
    6. import matplotlib.pyplot as plt
    7. import torchvision
    8. from torchvision import transforms
    9. import os
    10. import pandas
    11. os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
    12. #对数据做归一化(-1 --- 1
    13. tranform = transforms.Compose([
    14. transforms.ToTensor(), # 0-1归一化:channel,high,width
    15. transforms.Normalize(0.5,0.5), #均值,方差均为0.5
    16. ])
    17. train_ds = torchvision.datasets.MNIST("data",train = True,transform = tranform,download=True)
    18. dataloader = torch.utils.data.DataLoader(train_ds,batch_size = 64,shuffle = True)
    19. imgs, _ =next(iter(dataloader))
    20. print(imgs.shape) #torch.Size([64, 1, 28, 28])
    21. # batch_size=64 图片大小:1,28,28
    22. # 定义生成器
    23. # 输入是长度为100的噪声(正态分布随机数)
    24. # 输出为(1,28,28)的图片
    25. # linear1:100---256
    26. # linear2:256---512
    27. # linear1:512---28*28
    28. # linear2:28*28---(1,28,28)
    29. class Generator(nn.Module):
    30. def __init__(self):
    31. super(Generator,self).__init__()
    32. self.main = nn.Sequential(
    33. nn.Linear(100,256),
    34. nn.ReLU(),
    35. nn.Linear(256,512),
    36. nn.ReLU(),
    37. nn.Linear(512, 28*28),
    38. nn.Tanh()
    39. )
    40. def forward(self,x): # x表示长度为100的噪声输入
    41. img = self.main(x)
    42. img = img.view(-1,28,28)
    43. return img
    44. # 定义判别器
    45. # 输入为(1,28,28)的图片,输出为二分类的概率值,输出使用sigmoid激活
    46. # BCELose计算交叉熵损失
    47. # nn.LeakyReLU f(x):x>0,输出x,如果x<0,输出a*x,a表示一个很小的斜率,比如0.1
    48. #判别器中一般推荐使用nn.LeakyReLU
    49. class Discriminator(nn.Module):
    50. def __init__(self):
    51. super(Discriminator, self).__init__()
    52. self.main = nn.Sequential(
    53. nn.Linear(28*28,512),
    54. nn.LeakyReLU(),
    55. nn.Linear(512,256),
    56. nn.LeakyReLU(),
    57. nn.Linear(256, 1),
    58. nn.Sigmoid()
    59. )
    60. def forward(self,x):
    61. x = x.view(-1,28*28)
    62. x = self.main(x)
    63. return x
    64. # device = 'cuda' if torch.cuda.is_available() else 'cpu'
    65. device = 'cuda' if 0 else 'cpu'
    66. gen = Generator().to(device)
    67. dis = Discriminator().to(device)
    68. d_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
    69. g_optim = torch.optim.Adam(gen.parameters(),lr=0.0001)
    70. loss_fn = torch.nn.BCELoss()
    71. def gen_img_plot(model,test_input):
    72. prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    73. # detach()截断梯度,np.squeeze可以去掉维度为1
    74. fig = plt.figure(figsize=(4,4))
    75. for i in range(16): # prediction.size(0)=16
    76. plt.subplot(4,4,i+1)
    77. plt.imshow((prediction[i]+1)/2) # tanh 得到的是-1 - 1之间,-》0-1之间
    78. plt.axis("off")
    79. plt.show()
    80. test_input = torch.randn(16,100,device=device) # 16个长度为100的正态随机数
    81. # print(test_input)
    82. D_loss = []
    83. G_loss = []
    84. # 训练循环
    85. for epoch in range(20):
    86. d_epoch_loss = 0
    87. g_epoch_loss = 0
    88. count = len(dataloader) # len(dataloader)返回批次数
    89. # len(dataset)返回样本数
    90. for step,(img,_) in enumerate(dataloader):
    91. img = img.to(device)
    92. size = img.size(0)
    93. random_noise = torch.randn(size,100,device=device)
    94. d_optim.zero_grad()
    95. real_output = dis(img) # 对判别器输入真实图片 real_output对真实图片的预测结果
    96. #得到判别器在真实图像上面的损失
    97. d_real_loss = loss_fn(real_output,torch.ones_like(real_output))
    98. d_real_loss.backward()
    99. gen_img = gen(random_noise)
    100. #detach()截断生成器梯度,更新判别器梯度
    101. fake_output = dis(gen_img.detach()) # 判别器输入生成图片。fake_output对生成图片的预测
    102. # 得到判别器在生成图像上面的损失
    103. d_fake_loss = loss_fn(fake_output,torch.zeros_like(fake_output))
    104. d_fake_loss.backward()
    105. d_loss = d_fake_loss + d_real_loss
    106. d_optim.step()
    107. g_optim.zero_grad()
    108. fake_output = dis(gen_img)
    109. # 得到生成器的损失
    110. g_loss = loss_fn(fake_output,torch.ones_like(fake_output))
    111. g_loss.backward()
    112. g_optim.step()
    113. with torch.no_grad():
    114. d_epoch_loss +=d_loss
    115. g_epoch_loss +=g_loss
    116. with torch.no_grad():
    117. d_epoch_loss /=count
    118. g_epoch_loss /=count
    119. D_loss.append(d_epoch_loss)
    120. G_loss.append(g_epoch_loss)
    121. print("Epoch",epoch)
    122. gen_img_plot(gen,test_input)
  • 相关阅读:
    学习C++第二课
    校园网web免认真,大量服务器
    2022/07/04学习记录
    SparkSQL系列-1、快速入门
    全国职业技能大赛云计算--高职组赛题卷①(私有云)
    Vue(十)——页面路由(2)
    make&Makefile
    【KMP算法】大白话式详细图文解析(附代码)
    【2020】【论文笔记】使用CMOS和散射光学的THz分光——
    Maven的介绍和使用
  • 原文地址:https://blog.csdn.net/qq_42233059/article/details/126579791