生成对抗神经网络GAN,发挥神经网络的想象力,可以说是十分厉害了
1、AI作家
2、将模糊图变清晰(去雨,去雾,去抖动,去马赛克等),这需要AI具有“想象力”,能脑补情节;
3、进行数据增强,根据已有数据生成更多新数据供以feed,可以减缓模型过拟合现象。
那到底是怎么实现的呢?
GAN中有两大组成部分G和D
G是generator,生成器: 负责凭空捏造数据出来
D是discriminator,判别器: 负责判断数据是不是真数据
示例图如下:
给一个随机噪声z,通过G生成一张假图,然后用D去分辨是真图还是假图。假设G生成了一张图,在D那里的得分很高,那么G就很成功的骗过了D,如果D很轻松的分辨出了假图,那么G的效果不好,那么就需要调整参数了。
G和D是两个单独的网络,那么他们的参数都是训练好的吗?并不是,两个网络的参数是需要在博弈的过程中分别优化的。
下面就是一个训练的过程:
GAN在一轮反向传播中分为两步,先训练D在训练G。
训练D时,上一轮G产生的图片,和真实图片一起作为x进行输入,假图为0,真图标签为1,通过x生成一个score,通过score和标签y计算损失,就可以进行反向传播了。
训练G时,G和D是一个整体,取名为D_on_G。输入随机噪声,G产生一个假图,D去分辨,score = 1就是需要我们需要优化的目标,意思就是我们要让生成的图片变成真的。这里的D是不需要参与梯度计算的,我们通过反向传播来优化G,让他生成更加真实的图片。这就好比:如果你参加考试,你别指望能改变老师的评分标准
GAN无监督学习,(cGAN是有监督的),以后会学习的。怎么理解无监督学习呢?这里给的真图是没有经过人工标注的,只知道这是真的,D是不知道这是什么的,只需要分辨真假。G也不知道生成了什么,只需要学真图去骗D。
具体如何实施呢?
- import os
- import torch
- import torchvision
- import torch.nn as nn
- from torchvision import transforms
- from torchvision.utils import save_image
-
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- latent_size = 64
- hidden_size = 256
- image_size = 784
- num_epochs = 200
- batch_size = 100
- sample_dir = 'samples'
注意这里有个归一化的过程,MNIST是单通道,但是如果mean=(0.5,0.5,0.5)会报错,因为是对3通道操作 。
- if not os.path.exists(sample_dir):
- os.makedirs(sample_dir)
-
- transform = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize(mean=(0.5,), # 3 for RGB channels
- std=(0.5,))])
-
- # MNIST dataset
- mnist = torchvision.datasets.MNIST(root='./data/',train=True,transform=transform,download=True)
- # Data loader
- data_loader = torch.utils.data.DataLoader(dataset=mnist,batch_size=batch_size, shuffle=True)
定义生成器和判别器:
生成器:可以看到输入的维度为64,是一组噪声图像,通过生成器将特征扩大到了MNIST图像大小784。
判别器:输入维度为图像大小,最后输出特征个数为1,采用sigmoid激活(不用softmax的)
- # Discriminator
- D = nn.Sequential(
- nn.Linear(image_size, hidden_size),
- nn.LeakyReLU(0.2),
- nn.Linear(hidden_size, hidden_size),
- nn.LeakyReLU(0.2),
- nn.Linear(hidden_size, 1),
- nn.Sigmoid())
-
-
- # Generator
- G = nn.Sequential(
- nn.Linear(latent_size, hidden_size),
- nn.ReLU(),
- nn.Linear(hidden_size, hidden_size),
- nn.ReLU(),
- nn.Linear(hidden_size, image_size),
- nn.Tanh())
- # Device setting
- D = D.to(device)
- G = G.to(device)
-
- # Binary cross entropy loss and optimizer
- criterion = nn.BCELoss()
- d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
- g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
-
-
- def denorm(x):
- out = (x + 1) / 2
- return out.clamp(0, 1)
-
- def reset_grad():
- d_optimizer.zero_grad()
- g_optimizer.zero_grad()
重点看训练部分,我们到底是如何来训练GAN的。
判别器部分:判别器的损失值分为两部分,(一)将mini_batch定义为正样本,告诉他我是正品,所以设置标签为1。优化判别器判断正品的能力;(二)生成一幅赝品,再给判别器判别,这时候赝品的标签为0,优化判断赝品的能力。所以总损失为这两部分之和,计算梯度,优化判别器参数。
G_on_D:输入一个噪声,让生成器生成一幅图像,然后让D去判别,计算和正品之间的距离,即损失。反向传播,优化G的参数。
- # Start training
- total_step = len(data_loader)
- for epoch in range(num_epochs):
- for i, (images, _) in enumerate(data_loader):
- images = images.reshape(batch_size, -1).to(device)
-
- # Create the labels which are later used as input for the BCE loss
- real_labels = torch.ones(batch_size, 1).to(device)
- fake_labels = torch.zeros(batch_size, 1).to(device)
-
- # ================================================================== #
- # Train the discriminator #
- # ================================================================== #
-
- # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
- # Second term of the loss is always zero since real_labels == 1
- outputs = D(images)
- d_loss_real = criterion(outputs, real_labels)
- real_score = outputs
-
- # Compute BCELoss using fake images
- # First term of the loss is always zero since fake_labels == 0
- z = torch.randn(batch_size, latent_size).to(device)
- fake_images = G(z)
- outputs = D(fake_images)
- d_loss_fake = criterion(outputs, fake_labels)
- fake_score = outputs
-
- # Backprop and optimize
- d_loss = d_loss_real + d_loss_fake
- reset_grad()
- d_loss.backward()
- d_optimizer.step()
-
- # ================================================================== #
- # Train the generator #
- # ================================================================== #
-
- # Compute loss with fake images
- z = torch.randn(batch_size, latent_size).to(device)
- fake_images = G(z)
- outputs = D(fake_images)
-
- # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
- # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
- g_loss = criterion(outputs, real_labels)
-
- # Backprop and optimize
- reset_grad()
- g_loss.backward()
- g_optimizer.step()
-
- if (i+1) % 200 == 0:
- print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
- .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
- real_score.mean().item(), fake_score.mean().item()))
-
- # Save real images
- if (epoch+1) == 1:
- images = images.reshape(images.size(0), 1, 28, 28)
- save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
-
- # Save sampled images
- fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
- save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
训练完了怎么用?
只要用我们的生成器就可以随意生成了。
- import matplotlib.pyplot as plt
- z = torch.randn(1,latent_size).to(device)
- output = G(z)
- plt.imshow(output.cpu().data.numpy().reshape(28,28),cmap='gray')
- plt.show()
下面就是随机生成的图像了!