• CGAN理论讲解及代码实现


    目录

    1.原始GAN的缺点

    2.CGAN中心思想

    3.原始GAN和CGAN的区别

    4.CGAN代码实现 

    5.运行结果

    6.CGAN缺陷


    1.原始GAN的缺点

    生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。

    针对原始GAN不能生成具有特定属性的图片的问题,Mehdi Mirza等人提出了CGAN,其核心在于将属性信息y融入生成器和判别器中,属性y可以是任何标签的信息,例如图像的类别,人脸图像的面部表情等。

    2.CGAN中心思想

    CGAN的中心思想是希望可以控制GAN生成的图片,而不是单纯的随机生成图片。具体来说,Conditional  GAN在生成器和判别器的输入中添加了额外的条件信息,生成器生成的图片只有足够真实且条件相符,才能够通过判别器。

    3.原始GAN和CGAN的区别

    从公式上来看,CGAN相当于在原始GAN的基础上对生成器部分和判别器部分都加了一个条件。

    从模型上来看,如下图所示

    为了实现条件GAN的目的,生成网络和判别网络的原理和训练方式均要有所改变。模型部分,在判别器和生成器中都添加了额外信息y,y可以是类别标签或者是其他类型的数据,可以将y作为一个额外的输入层丢入判别器和生成器。

    4.CGAN代码实现 

    1. #导入库
    2. import torch
    3. import torch.nn as nn
    4. import torch.nn.functional as F
    5. from torch.utils import data
    6. import torchvision #加载图片
    7. from torchvision import transforms #图片变换
    8. import numpy as np
    9. import matplotlib.pyplot as plt #绘图
    10. import os
    11. import glob
    12. from PIL import Image
    13. #独热编码
    14. def one_hot(x,class_count=10):
    15. return torch.eye(class_count)[x,:]
    16. transform = transforms.Compose([
    17. transforms.ToTensor(), #取值范围会被归一化到(0,1)之间
    18. transforms.Normalize(mean=0.5,std=0.5) #设置均值和方差均为0.5
    19. ])
    20. #加载数据集
    21. dataset = torchvision.datasets.MNIST('data',
    22. train=True,
    23. transform=transform,
    24. target_transform = one_hot,
    25. download = True)
    26. dl = torch.utils.data.DataLoader(dataset,batch_size=64,shuffle = True)
    27. #定义生成器
    28. class Generator(nn.Module):
    29. def __init__(self):
    30. super(Generator,self).__init__()
    31. self.linear1 = nn.Linear(100,128*7*7)
    32. self.bn1=nn.BatchNorm1d(128*7*7)
    33. self.linear2 = nn.Linear(10,128*7*7)
    34. self.bn2=nn.BatchNorm1d(128*7*7)
    35. self.deconv1 = nn.ConvTranspose2d(256,128,
    36. kernel_size=(3,3),
    37. stride=1,
    38. padding=1) #生成(128,7,7)的二维图像
    39. self.bn3=nn.BatchNorm2d(128)
    40. self.deconv2 = nn.ConvTranspose2d(128,64,
    41. kernel_size=(4,4),
    42. stride=2,
    43. padding=1) #生成(64,14,14)的二维图像
    44. self.bn4=nn.BatchNorm2d(64)
    45. self.deconv3 = nn.ConvTranspose2d(64,1,
    46. kernel_size=(4,4),
    47. stride=2,
    48. padding=1) #生成(1,28,28)的二维图像
    49. def forward(self,x1,x2):
    50. x1=F.relu(self.linear1(x1))
    51. x1=self.bn1(x1)
    52. x1=x1.view(-1,128,7,7)
    53. x2=F.relu(self.linear2(x2))
    54. x2=self.bn2(x2)
    55. x2=x2.view(-1,128,7,7)
    56. x=torch.cat([x1,x2],axis=1) #batch, 256, 7, 7
    57. x=x.view(-1,256,7,7)
    58. x=F.relu(self.deconv1(x))
    59. x=self.bn3(x)
    60. x=F.relu(self.deconv2(x))
    61. x=self.bn4(x)
    62. x=torch.tanh(self.deconv3(x))
    63. return x
    64. #定义判别器
    65. #输入:1,28,28图片和长度为10的condition
    66. class Discriminator(nn.Module):
    67. def __init__(self):
    68. super(Discriminator,self).__init__()
    69. self.linear = nn.Linear(10,1*28*28)
    70. self.conv1 = nn.Conv2d(2,64,kernel_size=3,stride=2)
    71. self.conv2 = nn.Conv2d(64,128,kernel_size=3,stride=2)
    72. self.bn = nn.BatchNorm2d(128)
    73. self.fc = nn.Linear(128*6*6,1)
    74. def forward(self,x1,x2): #x1代表label,x2代表image
    75. x1=F.leaky_relu(self.linear(x1))
    76. x1=x1.view(-1,1,28,28)
    77. x=torch.cat([x1,x2],axis=1) #shape:batch,2,28,28
    78. x= F.dropout2d(F.leaky_relu(self.conv1(x)))
    79. x= F.dropout2d(F.leaky_relu(self.conv2(x)) ) #(batch,128,6,6)
    80. x = self.bn(x)
    81. x = x.view(-1,128*6*6) #展平
    82. x = torch.sigmoid(self.fc(x))
    83. return x
    84. #模型训练
    85. #设备的配置
    86. device='cuda' if torch.cuda.is_available() else 'cpu'
    87. #初化生成器和判别器把他们放到相应的设备上
    88. gen = Generator().to(device)
    89. dis = Discriminator().to(device)
    90. #交叉熵损失函数
    91. loss_fn = torch.nn.BCELoss()
    92. #训练器的优化器
    93. d_optimizer = torch.optim.Adam(dis.parameters(),lr=1e-5)
    94. #训练生成器的优化器
    95. g_optimizer = torch.optim.Adam(gen.parameters(),lr=1e-4)
    96. #定义可视化函数
    97. def generate_and_save_images(model,epoch,label_input,noise_input):
    98. prediction = np.squeeze(model(noise_input,label_input).cpu().numpy())
    99. fig = plt.figure(figsize=(4,4))
    100. for i in range(prediction.shape[0]):
    101. plt.subplot(4,4,i+1)
    102. plt.imshow((prediction[i]+1)/2,cmap='gray')
    103. plt.axis('off')
    104. plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    105. plt.show()
    106. #设置生成绘图图片的随机张量,这里可视化16张图片
    107. #生成16个长度为100的随机正态分布张量
    108. noise_seed = torch.randn(16,100,device=device)
    109. label_seed = torch.randint(0,10,size=(16,))
    110. label_seed_onehot = one_hot(label_seed).to(device)
    111. D_loss = [] #记录训练过程中判别器的损失
    112. G_loss = [] #记录训练过程中生成器的损失
    113. #训练循环
    114. for epoch in range(10):
    115. #初始化损失值
    116. D_epoch_loss = 0
    117. G_epoch_loss = 0
    118. count = len(dl.dataset) #返回批次数
    119. #对数据集进行迭代
    120. for step,(img,label) in enumerate(dl):
    121. img =img.to(device) #把数据放到设备上
    122. label = label.to(device)
    123. size = img.shape[0] #img的第一位是size,获取批次的大小
    124. random_seed = torch.randn(size,100,device=device)
    125. #判别器训练(真实图片的损失和生成图片的损失),损失的构建和优化
    126. d_optimizer.zero_grad()#梯度归零
    127. #判别器对于真实图片产生的损失
    128. real_output = dis(label,img) #判别器输入真实的图片,real_output对真实图片的预测结果
    129. d_real_loss = loss_fn(real_output,
    130. torch.ones_like(real_output,device=device)
    131. )
    132. d_real_loss.backward()#计算梯度
    133. #在生成器上去计算生成器的损失,优化目标是判别器上的参数
    134. generated_img = gen(random_seed,label) #得到生成的图片
    135. #因为优化目标是判别器,所以对生成器上的优化目标进行截断
    136. fake_output = dis(label,generated_img.detach()) #判别器输入生成的图片,fake_output对生成图片的预测;detach会截断梯度,梯度就不会再传递到gen模型中了
    137. #判别器在生成图像上产生的损失
    138. d_fake_loss = loss_fn(fake_output,
    139. torch.zeros_like(fake_output,device=device)
    140. )
    141. d_fake_loss.backward()
    142. #判别器损失
    143. disc_loss = d_real_loss + d_fake_loss
    144. #判别器优化
    145. d_optimizer.step()
    146. #生成器上损失的构建和优化
    147. g_optimizer.zero_grad() #先将生成器上的梯度置零
    148. fake_output = dis(label,generated_img)
    149. gen_loss = loss_fn(fake_output,
    150. torch.ones_like(fake_output,device=device)
    151. ) #生成器损失
    152. gen_loss.backward()
    153. g_optimizer.step()
    154. #累计每一个批次的loss
    155. with torch.no_grad():
    156. D_epoch_loss +=disc_loss
    157. G_epoch_loss +=gen_loss
    158. #求平均损失
    159. with torch.no_grad():
    160. D_epoch_loss /=count
    161. G_epoch_loss /=count
    162. D_loss.append(D_epoch_loss)
    163. G_loss.append(G_epoch_loss)
    164. #训练完一个Epoch,打印提示并绘制生成的图片
    165. print("Epoch:",epoch)
    166. print(label_seed)
    167. generate_and_save_images(gen,epoch,label_seed_onehot,noise_seed)

    5.运行结果

    因篇幅有限,只展示一部分运行结果

     

     

     

     

     

     

     

    6.CGAN缺陷

    CGAN生成的图像虽然有很多缺陷,譬如图像边缘模糊,生成的图像分辨率太低,但是它为后面的pix2pixGAN和CycleGAN开拓了道路,这两个模型转换图像网络时对属性特征的处理方法均受到CGAN启发。


    希望我的文章能对你有所帮助。欢迎👍点赞 ,📝评论,🌟关注,⭐️收藏

                                                                        

     

  • 相关阅读:
    一个开源的音频分离深度学习项目
    购买发票自动化软件(或者文档管理系统)需要注意的问题
    金融行业分布式数据库选型及实践经验
    poi-tl 用word模板生成报告
    AI写作宝-为什么要使用写作宝
    【云原生 | 从零开始学Kubernetes】二、使用kubeadm搭建K8S集群
    docker compose
    IBT机考-PBT笔考,优劣分析,柯桥口语学习,韩语入门,topik考级韩语
    软考中级哪一门比较好过?
    数据结构————树
  • 原文地址:https://blog.csdn.net/weixin_51781852/article/details/126203331