[Deep Learning] pix2pix GAN: Theory and Code Implementation


Contents

1. What is pix2pix GAN

2. Designing the pix2pix GAN generator

3. Designing the pix2pix GAN discriminator

4. Loss functions

5. Code implementation


1. What is pix2pix GAN

pix2pix GAN is essentially a cGAN in which the image x acts as the condition and is fed to both G and D. The input of G is x (the image to be translated) and its output is the generated image G(x). D then has to distinguish the fake pair (x, G(x)) from the real pair (x, y).

pix2pix GAN is mainly used to convert one kind of image into another, a task also known as image-to-image translation.
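The conditioning can be pictured as stacking the condition x with an image along the channel dimension before handing the pair to D. Below is a minimal sketch of that idea using random tensors as stand-ins (the real models are built in section 5); it also shows why the discriminator there takes 6 input channels.

import torch

# Minimal sketch of cGAN conditioning with random stand-in tensors.
x = torch.randn(1, 3, 256, 256)      # condition: the image to be translated
y = torch.randn(1, 3, 256, 256)      # real target image
g_x = torch.randn(1, 3, 256, 256)    # stand-in for the generator output G(x)

real_pair = torch.cat([x, y], dim=1)    # shape (1, 6, 256, 256): D should call this "real"
fake_pair = torch.cat([x, g_x], dim=1)  # shape (1, 6, 256, 256): D should call this "fake"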

2. Designing the pix2pix GAN generator

For an image translation task, the input and output share a great deal of information; the contour/edge structure, for example, is common to both. How do we preserve this shared information? It has to be considered both in the network architecture and in the design of the loss function.

With an ordinary convolutional encoder-decoder network, every layer has to carry all of that information through the bottleneck, so the network easily loses some of it.

For this reason, we use a U-Net as the generator: its skip connections let the shared low-level information flow directly from the encoder layers to the decoder layers.
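A single skip connection can be sketched as follows (shapes are illustrative; the full U-Net generator is defined in section 5): an encoder feature map is concatenated with the decoder feature map of the same spatial size, so shared details do not have to squeeze through the bottleneck.

import torch

# One U-Net skip connection, with illustrative shapes.
enc_feat = torch.randn(1, 64, 128, 128)           # feature map from an encoder (downsampling) layer
dec_feat = torch.randn(1, 64, 128, 128)           # decoder feature map at the same resolution
merged = torch.cat([dec_feat, enc_feat], dim=1)   # (1, 128, 128, 128): low-level detail flows across directly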

3. Designing the pix2pix GAN discriminator

The discriminator takes pairs of images as input, just as in a cGAN: a real pair (x, y) should be judged as 1, while the pair of the input and the generated image (x, G(x)) should be judged as 0. The generator, for its part, wants the discriminator to output 1 for (x, G(x)).

In the pix2pix paper, D is implemented as a patch discriminator (PatchGAN): no matter how large the generated image is, it is judged patch by patch, with D producing one real/fake prediction per fixed-size patch rather than a single score for the whole image.

The benefit of this design is that D effectively works on small inputs, so it has fewer parameters, the computation is cheaper, and training is faster.
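In practice the "patch" splitting is not done by literally cropping the image: a fully convolutional D outputs a grid of probabilities, one per receptive-field patch, and the BCE targets are grids of the same shape. A sketch with random stand-ins (the 60×60 size matches the discriminator built in section 5):

import torch

# Patch-discriminator targets: one label per patch instead of one per image.
patch_pred = torch.rand(1, 1, 60, 60)               # stand-in for D's output grid of probabilities
real_target = torch.ones_like(patch_pred)           # every patch should be judged real
fake_target = torch.zeros_like(patch_pred)          # every patch should be judged fake
loss = torch.nn.BCELoss()(patch_pred, real_target)  # BCE averaged over all patches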

4. Loss functions

Discriminator loss: for a real pair (input image x and ground-truth image y) the output should be 1; for the pair of the input image and the generated image the output should be 0.

Generator loss: for the pair of the input image and the generated image, the generator wants the discriminator to output 1.

The corresponding cGAN objective is:

$\mathcal{L}_{cGAN}(G,D)=\mathbb{E}_{x,y}[\log D(x,y)]+\mathbb{E}_{x,z}[\log(1-D(x,G(x,z)))]$

For image translation, the input and output of G share a lot of information, so to keep the generated image close to the ground truth an L1 loss is added:

$\mathcal{L}_{L1}(G)=\mathbb{E}_{x,y,z}[\lVert y-G(x,z)\rVert_1]$

Combining the two terms, the full objective is:

$G^*=\arg\min_G\max_D\,\mathcal{L}_{cGAN}(G,D)+\lambda\,\mathcal{L}_{L1}(G)$
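In code, this objective becomes the two generator-side terms used in the training loop of section 5: a BCE term on the patch predictions plus an L1 term weighted by LAMBDA (which plays the role of λ). A sketch with random stand-in tensors:

import torch

# How the combined generator objective is evaluated (random stand-in tensors).
LAMBDA = 7
disc_out = torch.rand(1, 1, 60, 60)          # D(x, G(x)) patch predictions
gen_output = torch.randn(1, 3, 256, 256)     # G(x)
target_img = torch.randn(1, 3, 256, 256)     # ground-truth y
adv = torch.nn.BCELoss()(disc_out, torch.ones_like(disc_out))   # cGAN term: G wants D to say "real"
l1 = torch.mean(torch.abs(gen_output - target_img))             # L1 term
gen_loss = adv + LAMBDA * l1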

5. Code implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision                      # image loading
from torchvision import transforms      # image transforms
import numpy as np
import matplotlib.pyplot as plt         # plotting
import os
import glob
from PIL import Image

# Sort both lists so each photo (.jpg) lines up with its annotation (.png)
imgs_path = sorted(glob.glob('base/*.jpg'))
annos_path = sorted(glob.glob('base/*.png'))

# Preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
# Dataset definition
class CMP_dataset(data.Dataset):
    def __init__(self, imgs_path, annos_path):
        self.imgs_path = imgs_path
        self.annos_path = annos_path

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        anno_path = self.annos_path[index]
        pil_img = Image.open(img_path)        # read the photo
        pil_img = transform(pil_img)          # transform it
        anno_img = Image.open(anno_path)      # read the annotation
        anno_img = anno_img.convert("RGB")    # make sure the annotation has 3 (RGB) channels
        pil_anno = transform(anno_img)        # transform it
        return pil_anno, pil_img

    def __len__(self):
        return len(self.imgs_path)

# Create the dataset
dataset = CMP_dataset(imgs_path, annos_path)

# Wrap it in a DataLoader for easy iteration
BATCHSIZE = 32
dataloader = data.DataLoader(dataset,
                             batch_size=BATCHSIZE,
                             shuffle=True)
annos_batch, imgs_batch = next(iter(dataloader))
# Downsampling block: stride-2 convolution + LeakyReLU, with optional BatchNorm
class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3,
                      stride=2,
                      padding=1),
            nn.LeakyReLU(inplace=True))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_bn=True):
        x = self.conv_relu(x)
        if is_bn:
            x = self.bn(x)
        return x

# Upsampling block: transposed convolution + LeakyReLU, with optional dropout
class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Upsample, self).__init__()
        self.upconv_relu = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),   # transposed convolution: doubles H and W
            nn.LeakyReLU(inplace=True))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_drop=False):
        x = self.upconv_relu(x)
        x = self.bn(x)
        if is_drop:
            x = F.dropout2d(x)
        return x
# Generator: 6 downsampling blocks, 5 upsampling blocks and one output layer (U-Net)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.down1 = Downsample(3, 64)       # 64, 128, 128
        self.down2 = Downsample(64, 128)     # 128, 64, 64
        self.down3 = Downsample(128, 256)    # 256, 32, 32
        self.down4 = Downsample(256, 512)    # 512, 16, 16
        self.down5 = Downsample(512, 512)    # 512, 8, 8
        self.down6 = Downsample(512, 512)    # 512, 4, 4
        self.up1 = Upsample(512, 512)        # 512, 8, 8
        self.up2 = Upsample(1024, 512)       # 512, 16, 16
        self.up3 = Upsample(1024, 256)       # 256, 32, 32
        self.up4 = Upsample(512, 128)        # 128, 64, 64
        self.up5 = Upsample(256, 64)         # 64, 128, 128
        self.last = nn.ConvTranspose2d(128, 3,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1)   # 3, 256, 256

    def forward(self, x):
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)
        x6 = self.down6(x5)
        x6 = self.up1(x6, is_drop=True)
        x6 = torch.cat([x6, x5], dim=1)      # skip connection to down5
        x6 = self.up2(x6, is_drop=True)
        x6 = torch.cat([x6, x4], dim=1)      # skip connection to down4
        x6 = self.up3(x6, is_drop=True)
        x6 = torch.cat([x6, x3], dim=1)      # skip connection to down3
        x6 = self.up4(x6)
        x6 = torch.cat([x6, x2], dim=1)      # skip connection to down2
        x6 = self.up5(x6)
        x6 = torch.cat([x6, x1], dim=1)      # skip connection to down1
        x6 = torch.tanh(self.last(x6))
        return x6
# Discriminator: the annotation and an image (generated or real) are concatenated along the channel axis
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.down1 = Downsample(6, 64)
        self.down2 = Downsample(64, 128)
        self.conv1 = nn.Conv2d(128, 256, 3)
        self.bn = nn.BatchNorm2d(256)
        self.last = nn.Conv2d(256, 1, 3)

    def forward(self, anno, img):
        x = torch.cat([anno, img], dim=1)
        x = self.down1(x, is_bn=False)
        x = self.down2(x, is_bn=True)
        x = F.dropout2d(self.bn(F.leaky_relu(self.conv1(x))))
        x = torch.sigmoid(self.last(x))    # batch*1*60*60 grid of patch predictions
        return x

device = "cuda" if torch.cuda.is_available() else "cpu"
gen = Generator().to(device)
dis = Discriminator().to(device)
d_optimizer = torch.optim.Adam(dis.parameters(), lr=1e-3, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(gen.parameters(), lr=1e-3, betas=(0.5, 0.999))
# Plotting: show the input annotation, the ground truth and the generated output side by side
def generate_images(model, test_anno, test_real):
    prediction = model(test_anno).permute(0, 2, 3, 1).detach().cpu().numpy()
    test_anno = test_anno.permute(0, 2, 3, 1).cpu().numpy()
    test_real = test_real.permute(0, 2, 3, 1).cpu().numpy()
    plt.figure(figsize=(10, 10))
    display_list = [test_anno[0], test_real[0], prediction[0]]
    title = ['Input', 'Ground Truth', 'Output']
    for i in range(3):
        plt.subplot(1, 3, i + 1)
        plt.title(title[i])
        plt.imshow(display_list[i] * 0.5 + 0.5)   # undo Normalize(0.5, 0.5) so values fall back into [0, 1]
        plt.axis('off')                           # hide the axes
    plt.show()

test_imgs_path = sorted(glob.glob('extended/*.jpg'))
test_annos_path = sorted(glob.glob('extended/*.png'))
test_dataset = CMP_dataset(test_imgs_path, test_annos_path)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCHSIZE)
# Loss functions
# cGAN adversarial loss
loss_fn = torch.nn.BCELoss()
# The L1 loss is computed directly with torch.abs() inside the training loop

# Keep one fixed batch on the device for visualisation at the end of each epoch
annos_batch, imgs_batch = annos_batch.to(device), imgs_batch.to(device)
LAMBDA = 7      # weight of the L1 loss
D_loss = []     # discriminator loss per epoch
G_loss = []     # generator loss per epoch

# Training loop
for epoch in range(10):
    D_epoch_loss = 0
    G_epoch_loss = 0
    count = len(dataloader)
    for step, (annos, imgs) in enumerate(dataloader):
        imgs = imgs.to(device)
        annos = annos.to(device)

        # Discriminator: loss computation and optimization
        d_optimizer.zero_grad()
        disc_real_output = dis(annos, imgs)           # real pairs
        d_real_loss = loss_fn(disc_real_output,
                              torch.ones_like(disc_real_output, device=device))
        d_real_loss.backward()
        gen_output = gen(annos)
        disc_gen_output = dis(annos, gen_output.detach())
        d_fake_loss = loss_fn(disc_gen_output,
                              torch.zeros_like(disc_gen_output, device=device))
        d_fake_loss.backward()
        disc_loss = d_real_loss + d_fake_loss         # total discriminator loss
        d_optimizer.step()

        # Generator: loss computation and optimization
        g_optimizer.zero_grad()
        disc_gen_out = dis(annos, gen_output)
        gen_loss_crossentropyloss = loss_fn(disc_gen_out,
                                            torch.ones_like(disc_gen_out, device=device))
        gen_l1_loss = torch.mean(torch.abs(gen_output - imgs))
        gen_loss = gen_loss_crossentropyloss + LAMBDA * gen_l1_loss
        gen_loss.backward()     # backpropagation
        g_optimizer.step()      # optimization step

        # Accumulate the loss of each batch
        with torch.no_grad():
            D_epoch_loss += disc_loss.item()
            G_epoch_loss += gen_loss.item()

    # Average loss of the epoch
    with torch.no_grad():
        D_epoch_loss /= count
        G_epoch_loss /= count
        D_loss.append(D_epoch_loss)
        G_loss.append(G_epoch_loss)
        # After each epoch, print progress and plot generated images
        print("Epoch:", epoch)
        generate_images(gen, annos_batch, imgs_batch)
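The training loop above fills D_loss and G_loss but never displays them; a minimal optional sketch for plotting the curves after training, reusing the matplotlib import from the top of the script:

# Optional: plot the recorded per-epoch losses after training.
plt.plot(range(1, len(D_loss) + 1), D_loss, label='D loss')
plt.plot(range(1, len(G_loss) + 1), G_loss, label='G loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()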
Original article: https://blog.csdn.net/weixin_51781852/article/details/126238330