• GAN基础知识及代码


    GAN也叫做生成对抗网络,分为两部分,一个是生成网络G,一个是对抗网络D。生成网络和对抗网络进行竞争,生成模型可以被认为是造假者,他们试图制造假币并在不被发现的情况下使用它,而鉴别模型类似于警察,视图发现假币。在这个游戏中,竞争促使两个团队改进他们的方法,直到冒充的产品和正品无法区分。

    生成模型和判别模型都是多层感知器。

    噪声就是随机生成的数,通过生成器随机生成一张图。(所以生成器只能随机生成图像,不能指定一些条件)

    判别器的作用是尽可能的把真实数据集和生成数据集区分开,对于真实数据希望输出1,对于生成数据希望输出0。

    相反,生成器希望判别器读入生成数据,输出1。

    损失函数:

    简化一点就是 D(x) +( 1-D(G(z)) ) , log是单调递增函数,此处的作用是放大损失。

    对于生成器G,希望这个函数尽可能小,即D(x)接近0,1-D(G(z))接近0,即D(G(z))接近1.  事实上生成器不管D(x)是否是0,只要确保D(G(z))接近1,即生成的图像判别器判成了真实图像。

    标准代码 - (pytorch)

    在MNIST手写数字数据集上训练。

    1. import torch.nn as nn
    2. import torch.nn.functional as F
    3. import torch.optim as optim
    4. import numpy as np
    5. import matplotlib.pyplot as plt
    6. import torchvision
    7. from torchvision import transforms

    1. 数据准备

    对真实数据做归一化(-1,1),gan要求的,因为生成器生成的数据是(-1,1),保持两个数据分布一样

    1. transform = transforms.Compose([
    2. transforms.ToTensor(), # 归一化为0~1
    3. transforms.Normalize(0.5,0.5) # 归一化为-1~1
    4. ])
    5. train_ds = torchvision.datasets.MNIST('datasets', # 下载到那个目录下
    6. train=True,
    7. transform=transform,
    8. download=True)
    9. dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64,shuffle=True)
    10. imgs,_ = next(iter(dataloader))
    11. imgs.shape
    12. # torch.Size([64, 1, 28, 28])

    2. 定义生成器

    输入是长度100的噪声z(正态分布随机数)
    输出为(1,28,28)的图片,和MNIST数据集保持一致
    Linear1: 100->256
    Linear2: 256->512
    Linear3: 512->2828
    reshape: 28
    28->(1,28,28)

    1. class Generator(nn.Module):
    2. def __init__(self):
    3. super(Generator, self).__init__() # 继承父类
    4. self.main = nn.Sequential(
    5. nn.Linear(100,256), nn.ReLU(),
    6. nn.Linear(256,512), nn.ReLU(),
    7. nn.Linear(512,28*28),
    8. nn.Tanh() # 最后必须用tanh,把数据分布到(-1,1)之间
    9. )
    10. def forward(self, x): # x表示长度为100的噪声输入
    11. img = self.main(x)
    12. img = img.view(-1,28,28,1) # 方便等会绘图
    13. return img

    3. 定义判别器

    输入为(1,28,28)的mnist图片
    输出为二分类的概率,使用sigmoid激活,范围为0~1
    BCEloss计算交叉熵损失
    判别器推荐使用LeakReLU激活

    1. class Discriminator(nn.Module):
    2. def __init__(self):
    3. super(Discriminator,self).__init__()
    4. self.main = nn.Sequential(
    5. nn.Linear(28*28,512),
    6. nn.LeakyReLU(), # x小于零是是一个很小的值不是0,x大于0是还是x
    7. nn.Linear(512,256),
    8. nn.LeakyReLU(),
    9. nn.Linear(256,1),
    10. nn.Sigmoid() # 保证输出范围为(0,1)的概率
    11. )
    12. def forward(self, x): # x表示28*28的mnist图片
    13. img = x.view(-1,28*28)
    14. img = self.main(img)
    15. return img

    4. 初始化模型、优化器、损失函数

    1. device = 'cuda' if torch.cuda.is_available() else 'cpu'
    2. print('training on ',device)
    3. # 模型
    4. gen = Generator().to(device)
    5. dis = Discriminator().to(device)
    6. # 优化器
    7. g_opt = torch.optim.Adam(gen.parameters(),lr=0.0001)
    8. d_opt = torch.optim.Adam(dis.parameters(),lr=0.0001)
    9. # 损失
    10. loss = torch.nn.BCELoss()

    5. 绘图函数

    随时看到生成器生成的图像

    1. def gen_img_plot(model, test_input):
    2. prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    3. fig = plt.figure(figsize=(4,4))
    4. for i in range(16):
    5. plt.subplot(4,4,i+1) # 四行四列的第一个
    6. # imshow函数绘图的输入是(0,1)的float,或者(1,256)的int
    7. # 但prediction是tanh出来的范围是[-1,1]没法绘图,需要转成0~1(即加1除2)。
    8. plt.imshow( (prediction[i]+1)/2 )
    9. plt.axis('off')
    10. plt.show()
    11. test_input = torch.randn(16, 100, device=device)

    6. GAN训练

    1. D_loss = []
    2. G_loss = []
    3. epochs = 40
    4. for epoch in range(epochs):
    5. d_epoch_loss = 0
    6. g_epoch_loss = 0
    7. count = len(dataloader) # 一个epoch的大小
    8. for step, (img, _) in enumerate(dataloader):
    9. img = img.to(device) # 一个批次的图片
    10. size = img.size(0) # 和和图片对应的原始噪音
    11. random_noise = torch.randn(size, 100, device=device)
    12. gen_img = gen(random_noise) # 生成的图像
    13. d_opt.zero_grad()
    14. real_output = dis(img) # 判别器输入真实图片,对真实图片的预测结果,希望是1
    15. # 判别器在真实图像上的损失
    16. d_real_loss = loss(real_output, torch.ones_like(real_output)) # size一样全一的tensor
    17. d_real_loss.backward()
    18. g_opt.zero_grad()
    19. # 记得切断生成器的梯度
    20. fake_output = dis(gen_img.detach()) # 判别器输入生成图片,对生成图片的预测结果,希望是0
    21. # 判别器在生成图像上的损失
    22. d_fake_loss = loss(fake_output, torch.zeros_like(fake_output)) # size一样全一的tensor
    23. d_fake_loss.backward()
    24. d_loss = d_real_loss + d_fake_loss
    25. d_opt.step()
    26. # 生成器的损失
    27. g_opt.zero_grad()
    28. fake_output = dis(gen_img) # 希望被判定为1
    29. g_loss = loss(fake_output, torch.ones_like(fake_output))
    30. g_loss.backward()
    31. g_opt.step()
    32. # 每个epoch内的loss累加,循环外再除epoch大小,得到平均loss
    33. with torch.no_grad():
    34. d_epoch_loss += d_loss
    35. g_epoch_loss += g_loss
    36. # 一个epoch训练完成
    37. with torch.no_grad():
    38. d_epoch_loss /= count
    39. g_epoch_loss /= count
    40. D_loss.append(d_epoch_loss)
    41. G_loss.append(g_epoch_loss)
    42. print('Epoch: ',epoch)
    43. gen_img_plot(gen, test_input)

    结果:

     

     可以看到效果还算可以,这是2014年提出的最基础的GAN,后续要有若干改进的工作,效果更好,有机会再学。

  • 相关阅读:
    计算机网络复习笔记——运输层
    【博客461】BGP(边界网关协议)-----EBGP多跳和指定更新源问题分析
    [附源码]计算机毕业设计JAVAjsp铁路集装箱物流管理信息系统
    Vue3实战(1)
    关于 a (链接)标签 里面包含图片会被撑大的解决方法、a标签会撑大的解决方法
    牵手乌金石,共创财富梦—全国市场盛大启动
    强化学习与视觉语言模型之间的碰撞,UC伯克利提出语言奖励调节LAMP框架
    SSM之Mybatis概览
    如何利用AWS CloudFront 自定义设置SSL
    Python字典
  • 原文地址:https://blog.csdn.net/weixin_43828245/article/details/127037756