目录
生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。
针对原始GAN不能生成具有特定属性的图片的问题,Mehdi Mirza等人提出了CGAN,其核心在于将属性信息y融入生成器和判别器中,属性y可以是任何标签的信息,例如图像的类别,人脸图像的面部表情等。
CGAN的中心思想是希望可以控制GAN生成的图片,而不是单纯的随机生成图片。具体来说,Conditional GAN在生成器和判别器的输入中添加了额外的条件信息,生成器生成的图片只有足够真实且条件相符,才能够通过判别器。
从公式上来看,CGAN相当于在原始GAN的基础上对生成器部分和判别器部分都加了一个条件。

从模型上来看,如下图所示

为了实现条件GAN的目的,生成网络和判别网络的原理和训练方式均要有所改变。模型部分,在判别器和生成器中都添加了额外信息y,y可以是类别标签或者是其他类型的数据,可以将y作为一个额外的输入层丢入判别器和生成器。
- #导入库
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.utils import data
- import torchvision #加载图片
- from torchvision import transforms #图片变换
-
- import numpy as np
- import matplotlib.pyplot as plt #绘图
- import os
- import glob
- from PIL import Image
-
- #独热编码
- def one_hot(x,class_count=10):
- return torch.eye(class_count)[x,:]
-
- transform = transforms.Compose([
- transforms.ToTensor(), #取值范围会被归一化到(0,1)之间
- transforms.Normalize(mean=0.5,std=0.5) #设置均值和方差均为0.5
- ])
-
-
- #加载数据集
- dataset = torchvision.datasets.MNIST('data',
- train=True,
- transform=transform,
- target_transform = one_hot,
- download = True)
- dl = torch.utils.data.DataLoader(dataset,batch_size=64,shuffle = True)
-
- #定义生成器
- class Generator(nn.Module):
- def __init__(self):
- super(Generator,self).__init__()
- self.linear1 = nn.Linear(100,128*7*7)
- self.bn1=nn.BatchNorm1d(128*7*7)
- self.linear2 = nn.Linear(10,128*7*7)
- self.bn2=nn.BatchNorm1d(128*7*7)
-
- self.deconv1 = nn.ConvTranspose2d(256,128,
- kernel_size=(3,3),
- stride=1,
- padding=1) #生成(128,7,7)的二维图像
- self.bn3=nn.BatchNorm2d(128)
- self.deconv2 = nn.ConvTranspose2d(128,64,
- kernel_size=(4,4),
- stride=2,
- padding=1) #生成(64,14,14)的二维图像
- self.bn4=nn.BatchNorm2d(64)
- self.deconv3 = nn.ConvTranspose2d(64,1,
- kernel_size=(4,4),
- stride=2,
- padding=1) #生成(1,28,28)的二维图像
-
- def forward(self,x1,x2):
- x1=F.relu(self.linear1(x1))
- x1=self.bn1(x1)
- x1=x1.view(-1,128,7,7)
- x2=F.relu(self.linear2(x2))
- x2=self.bn2(x2)
- x2=x2.view(-1,128,7,7)
- x=torch.cat([x1,x2],axis=1) #batch, 256, 7, 7
- x=x.view(-1,256,7,7)
- x=F.relu(self.deconv1(x))
- x=self.bn3(x)
- x=F.relu(self.deconv2(x))
- x=self.bn4(x)
- x=torch.tanh(self.deconv3(x))
- return x
-
-
- #定义判别器
- #输入:1,28,28图片和长度为10的condition
- class Discriminator(nn.Module):
- def __init__(self):
- super(Discriminator,self).__init__()
- self.linear = nn.Linear(10,1*28*28)
- self.conv1 = nn.Conv2d(2,64,kernel_size=3,stride=2)
- self.conv2 = nn.Conv2d(64,128,kernel_size=3,stride=2)
- self.bn = nn.BatchNorm2d(128)
- self.fc = nn.Linear(128*6*6,1)
- def forward(self,x1,x2): #x1代表label,x2代表image
- x1=F.leaky_relu(self.linear(x1))
- x1=x1.view(-1,1,28,28)
- x=torch.cat([x1,x2],axis=1) #shape:batch,2,28,28
- x= F.dropout2d(F.leaky_relu(self.conv1(x)))
- x= F.dropout2d(F.leaky_relu(self.conv2(x)) ) #(batch,128,6,6)
- x = self.bn(x)
- x = x.view(-1,128*6*6) #展平
- x = torch.sigmoid(self.fc(x))
- return x
-
- #模型训练
- #设备的配置
- device='cuda' if torch.cuda.is_available() else 'cpu'
- #初化生成器和判别器把他们放到相应的设备上
- gen = Generator().to(device)
- dis = Discriminator().to(device)
- #交叉熵损失函数
- loss_fn = torch.nn.BCELoss()
- #训练器的优化器
- d_optimizer = torch.optim.Adam(dis.parameters(),lr=1e-5)
- #训练生成器的优化器
- g_optimizer = torch.optim.Adam(gen.parameters(),lr=1e-4)
- #定义可视化函数
- def generate_and_save_images(model,epoch,label_input,noise_input):
- prediction = np.squeeze(model(noise_input,label_input).cpu().numpy())
- fig = plt.figure(figsize=(4,4))
- for i in range(prediction.shape[0]):
- plt.subplot(4,4,i+1)
- plt.imshow((prediction[i]+1)/2,cmap='gray')
- plt.axis('off')
- plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
- plt.show()
- #设置生成绘图图片的随机张量,这里可视化16张图片
- #生成16个长度为100的随机正态分布张量
- noise_seed = torch.randn(16,100,device=device)
- label_seed = torch.randint(0,10,size=(16,))
- label_seed_onehot = one_hot(label_seed).to(device)
-
- D_loss = [] #记录训练过程中判别器的损失
- G_loss = [] #记录训练过程中生成器的损失
- #训练循环
- for epoch in range(10):
- #初始化损失值
- D_epoch_loss = 0
- G_epoch_loss = 0
- count = len(dl.dataset) #返回批次数
- #对数据集进行迭代
- for step,(img,label) in enumerate(dl):
- img =img.to(device) #把数据放到设备上
- label = label.to(device)
- size = img.shape[0] #img的第一位是size,获取批次的大小
- random_seed = torch.randn(size,100,device=device)
-
- #判别器训练(真实图片的损失和生成图片的损失),损失的构建和优化
- d_optimizer.zero_grad()#梯度归零
- #判别器对于真实图片产生的损失
- real_output = dis(label,img) #判别器输入真实的图片,real_output对真实图片的预测结果
- d_real_loss = loss_fn(real_output,
- torch.ones_like(real_output,device=device)
- )
- d_real_loss.backward()#计算梯度
-
- #在生成器上去计算生成器的损失,优化目标是判别器上的参数
- generated_img = gen(random_seed,label) #得到生成的图片
- #因为优化目标是判别器,所以对生成器上的优化目标进行截断
- fake_output = dis(label,generated_img.detach()) #判别器输入生成的图片,fake_output对生成图片的预测;detach会截断梯度,梯度就不会再传递到gen模型中了
- #判别器在生成图像上产生的损失
- d_fake_loss = loss_fn(fake_output,
- torch.zeros_like(fake_output,device=device)
- )
- d_fake_loss.backward()
- #判别器损失
- disc_loss = d_real_loss + d_fake_loss
- #判别器优化
- d_optimizer.step()
-
-
- #生成器上损失的构建和优化
- g_optimizer.zero_grad() #先将生成器上的梯度置零
- fake_output = dis(label,generated_img)
- gen_loss = loss_fn(fake_output,
- torch.ones_like(fake_output,device=device)
- ) #生成器损失
- gen_loss.backward()
- g_optimizer.step()
- #累计每一个批次的loss
- with torch.no_grad():
- D_epoch_loss +=disc_loss
- G_epoch_loss +=gen_loss
- #求平均损失
- with torch.no_grad():
- D_epoch_loss /=count
- G_epoch_loss /=count
- D_loss.append(D_epoch_loss)
- G_loss.append(G_epoch_loss)
- #训练完一个Epoch,打印提示并绘制生成的图片
- print("Epoch:",epoch)
- print(label_seed)
- generate_and_save_images(gen,epoch,label_seed_onehot,noise_seed)
因篇幅有限,只展示一部分运行结果





CGAN生成的图像虽然有很多缺陷,譬如图像边缘模糊,生成的图像分辨率太低,但是它为后面的pix2pixGAN和CycleGAN开拓了道路,这两个模型转换图像网络时对属性特征的处理方法均受到CGAN启发。
希望我的文章能对你有所帮助。欢迎👍点赞 ,📝评论,🌟关注,⭐️收藏
