[PyTorch][Chapter 57][WGAN-GP Code Implementation]


Preface:

WGAN's result on a toy dataset (a snippet to re-create the plot follows these notes):

  Green: the real data distribution, a mixture of 8 Gaussians

  Red: the generated data distribution, which ends up closely matching the real one
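A minimal sketch to visualize that target distribution, mirroring the data_generator defined in main.py below (the 1,000-sample count is arbitrary):

    import numpy as np
    from matplotlib import pyplot as plt

    a = np.sqrt(2.0)
    # 8 centers on a circle, scaled by 2 (same layout as data_generator in main.py)
    centers = 2 * np.array([(1, 0), (-1, 0), (0, 1), (0, -1),
                            (1/a, 1/a), (1/a, -1/a), (-1/a, 1/a), (-1/a, -1/a)])
    idx = np.random.randint(8, size=1000)                   # pick a random center per sample
    real = centers[idx] + np.random.randn(1000, 2) * 0.02   # add small Gaussian noise
    plt.scatter(real[:, 0], real[:, 1], c='green', marker='.')
    plt.show()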

WGAN-GP

1 Discriminator D: drop the sigmoid from the last layer
2 Generator G and discriminator D: the losses do not take a log
3 The loss adds a gradient penalty term, and Adam can be used (the objective is written out below)

Wasserstein GAN

1 Discriminator D: drop the sigmoid from the last layer
2 Generator G and discriminator D: the losses do not take a log
3 After every discriminator update, clip the absolute values of its weights to a fixed constant c
4 Do not use momentum-based optimizers (including momentum and Adam); RMSProp is recommended, and SGD also works
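Written out, the critic objective that the WGAN-GP changes produce (Gulrajani et al., 2017) is

    L_D = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[D(\tilde{x})] - \mathbb{E}_{x \sim \mathbb{P}_r}[D(x)] + \lambda\,\mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]

where \hat{x} = t\,x_r + (1-t)\,x_f with t \sim U[0,1], and the generator minimizes L_G = -\mathbb{E}_z[D(G(z))]. The paper recommends \lambda = 10; the code below uses lambd = 0.2.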
     


1 Introduction

1.1 Model structure

1.2 Pseudocode

See also: 从Wasserstein距离、对偶理论到WGAN - 科学空间|Scientific Spaces
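The pseudocode figure is not reproduced here; roughly, the training loop it describes (Algorithm 1 of Gulrajani et al., simplified to this post's setting) is:

    for each iteration:
        for k = 1 .. K:                      # K critic steps per generator step
            sample real batch x_r and noise z;  x_f = G(z)
            x_hat = t * x_r + (1 - t) * x_f,  t ~ U[0, 1]
            L_D = mean(D(x_f)) - mean(D(x_r)) + lambda * mean((||grad_{x_hat} D(x_hat)||_2 - 1)^2)
            Adam step on D to minimize L_D
        sample noise z
        L_G = -mean(D(G(z)))
        Adam step on G to minimize L_G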


2 wgan.py

Main change:

    The logit (sigmoid) function used in the previous chapter has been removed from the Generator; the Discriminator likewise ends in a bare linear layer. A quick shape check follows the listing.

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:10:19 2023
@author: chengxf2
"""
import torch
from torch import nn

h_dim = 400  # hidden-layer width

# Generator model
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # z: [batch, input_features]
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2)
        )

    def forward(self, z):
        output = self.net(z)
        return output

# Discriminator (critic) model
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        hDim = 400
        # x: [batch, input_features]
        self.net = nn.Sequential(
            nn.Linear(2, hDim),
            nn.ReLU(True),
            nn.Linear(hDim, hDim),
            nn.ReLU(True),
            nn.Linear(hDim, hDim),
            nn.ReLU(True),
            nn.Linear(hDim, 1),  # no sigmoid: the critic outputs an unbounded score
        )

    def forward(self, x):
        # x: [batch, 2] -> out: [batch]
        output = self.net(x)
        out = output.view(-1)
        return out
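A minimal shape check for these two modules (assuming the file is saved as wgan.py; not part of the original post):

    import torch
    from wgan import Generator, Discriminator

    G, D = Generator(), Discriminator()
    z = torch.randn(4, 2)             # a batch of 4 noise vectors in R^2
    x_fake = G(z)                     # [4, 2]: generated 2-D points
    score = D(x_fake)                 # [4]: unbounded critic scores (no sigmoid)
    print(x_fake.shape, score.shape)  # torch.Size([4, 2]) torch.Size([4])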

3 main.py

Main change:

    A gradient_penalty term is added to the discriminator's loss. A standalone sketch of the underlying autograd.grad call follows the listing.

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:28:32 2023
@author: chengxf2
"""
import visdom
from wgan import Discriminator
from wgan import Generator
import numpy as np
import random
import torch
from torch import nn, optim
from matplotlib import pyplot as plt
from torch import autograd

h_dim = 400
batchsz = 256
viz = visdom.Visdom()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def weights_init(net):
    if isinstance(net, nn.Linear):
        # net.weight.data.normal_(0.0, 0.02)
        nn.init.kaiming_normal_(net.weight)
        net.bias.data.fill_(0)

def data_generator():
    """
    Yields batches sampled from a mixture of 8 Gaussians
    whose centers lie on a circle.
    """
    scale = 2
    a = np.sqrt(2.0)
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1 / a, 1 / a),
        (1 / a, -1 / a),
        (-1 / a, 1 / a),
        (-1 / a, -1 / a)
    ]
    centers = [(scale * x, scale * y) for x, y in centers]
    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset).astype(np.float32)
        dataset /= a
        # a generator function returns an iterator
        yield dataset

def generate_image(D, G, xr, epoch):  # xr: real samples (on CPU)
    """
    Plots the true distribution, the generator's samples,
    and the critic's value surface.
    """
    N_POINTS = 128
    RANGE = 3
    plt.clf()
    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))  # (16384, 2)
    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    N = len(x)
    # draw the critic's contour map
    with torch.no_grad():
        points = torch.Tensor(points).to(device)  # [16384, 2]
        disc_map = D(points).cpu().numpy()        # [16384]
    plt.contour(x, y, disc_map.reshape((N, N)).transpose())
    # plt.clabel(cs, inline=1, fontsize=10)
    plt.colorbar()
    # draw samples
    with torch.no_grad():
        z = torch.randn(batchsz, 2).to(device)  # [b, 2]
        samples = G(z).cpu().numpy()            # [b, 2]
    plt.scatter(xr[:, 0], xr[:, 1], c='green', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='red', marker='+')
    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))

def gradient_penalty(D, xr, xf):
    # t: [b, 1], uniform on [0, 1]
    t = torch.rand(batchsz, 1).to(device)
    # [b,1] => [b,2]: the same t is applied to both coordinates of a sample
    t = t.expand_as(xr)
    # interpolate between real and fake samples: [b, 2]
    mid = t * xr + (1 - t) * xf
    mid.requires_grad_()
    pred = D(mid)  # [b]
    '''
    grad_outputs: required when outputs is a vector rather than a scalar
    retain_graph: True keeps the computation graph; False frees it
    create_graph: must be True to take higher-order derivatives
    allow_unused: permits inputs that do not enter the computation
    '''
    grads = autograd.grad(outputs=pred, inputs=mid,
                          grad_outputs=torch.ones_like(pred),
                          create_graph=True,
                          retain_graph=True,
                          only_inputs=True)[0]
    gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()
    return gp

def main():
    lambd = 0.2  # gradient-penalty weight (hyperparameter)
    maxIter = 1000
    torch.manual_seed(10)
    np.random.seed(10)
    data_iter = data_generator()
    G = Generator().to(device)
    D = Discriminator().to(device)
    G.apply(weights_init)
    D.apply(weights_init)
    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))
    K = 5  # critic updates per generator update
    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))
    for epoch in range(maxIter):
        # 1: train the Discriminator first
        for k in range(K):
            # 1.1: train on real data
            xr = next(data_iter)
            xr = torch.from_numpy(xr).to(device)
            predr = D(xr)
            # max(predr) == min(-predr)
            lossr = -predr.mean()
            # 1.2: train on fake data
            z = torch.randn(batchsz, 2).to(device)  # [b, 2] random noise
            xf = G(z).detach()  # freeze G, do not update its parameters (cf. tf.stop_gradient())
            predf = D(xf)
            lossf = predf.mean()
            # 1.3: gradient penalty
            gp = gradient_penalty(D, xr, xf.detach())
            # aggregate all terms
            loss_D = lossr + lossf + lambd * gp
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()
            # print("\n Discriminator step done ", loss_D.item())
        # 2: train the Generator
        # 2.1: train on fake data
        z = torch.randn(batchsz, 2).to(device)
        xf = G(z)
        predf = D(xf)  # we want the critic's score on fakes to be as high as possible
        loss_G = -predf.mean()
        # optimize
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()
        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')
            generate_image(D, G, xr.cpu(), epoch)
            print("\n epoch: %d" % epoch, "\t lossD: %7.4f" % loss_D.item(), "\t lossG: %7.4f" % loss_G.item())

if __name__ == "__main__":
    main()
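As promised above, here is a standalone sketch of what the autograd.grad call inside gradient_penalty computes; the quadratic "critic" is made up purely for illustration:

    import torch
    from torch import autograd

    x = torch.randn(4, 2, requires_grad=True)
    pred = (x ** 2).sum(dim=1)              # toy critic: D(x) = ||x||^2, so dD/dx = 2x
    grads = autograd.grad(outputs=pred, inputs=x,
                          grad_outputs=torch.ones_like(pred),  # needed: pred is a vector
                          create_graph=True)[0]                # [4, 2] per-sample gradients
    print(torch.allclose(grads, 2 * x))     # True
    gp = ((grads.norm(2, dim=1) - 1) ** 2).mean()  # the penalty; still differentiable

Because create_graph=True, gp itself is differentiable, so loss_D.backward() can propagate through the gradient norm into D's parameters.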

References:

    课时130 WGAN-GP实战_哔哩哔哩_bilibili

    WGAN基本原理及Pytorch实现WGAN-CSDN博客


Original post: https://blog.csdn.net/chengxf2/article/details/133670461