• GAN: Generative Adversarial Nets


    Core idea: the paper proposes a new framework for estimating generative models via an adversarial process, in which two models are trained simultaneously: a generative model G that captures the data distribution, and a discriminative model D that estimates the probability that a sample came from the training data rather than from the generator. The training procedure for G is to maximize the probability of D making a mistake. The framework corresponds to a minimax two-player game. In the space of arbitrary functions G and D, a unique solution exists, in which G recovers the training data distribution and D equals 1/2 everywhere.
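
    Written out, the two-player game from the paper is the minimax problem over the value function V(D, G) (in LaTeX notation):

        \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]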

    Note: the output of D is a probability, namely the probability that the sample came from the training data rather than from the generator.

    (1) For the discriminator D: 1. maximize the probability that a real image fed into D is classified as real, i.e. maximize log D(x);

    2. minimize the probability that a fake image produced by G and fed into D is classified as real, i.e. maximize log(1 - D(G(z))).

    (2) For the generator G: the goal is to fool the discriminator, i.e. make it classify the fake images produced by G as real.

    In terms of the objective function: maximize D(G(z)), which is equivalent to minimizing log(1 - D(G(z))); see the sketch below for how this is expressed as a loss.
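
    A minimal sketch of how these two objectives map onto binary cross-entropy losses, in the same style as the training loop below (the tensors here are illustrative stand-ins for discriminator outputs, not real model outputs):

    import torch
    import torch.nn as nn

    bce = nn.BCELoss()

    # Stand-in discriminator outputs: probabilities in (0, 1), shape [batch, 1].
    d_real = torch.full((4, 1), 0.9)   # as if D(real_images)
    d_fake = torch.full((4, 1), 0.2)   # as if D(G(z))
    ones = torch.ones(4, 1)
    zeros = torch.zeros(4, 1)

    # Discriminator objective: push D(x) toward 1 and D(G(z)) toward 0,
    # i.e. minimize -log D(x) - log(1 - D(G(z))).
    d_loss = bce(d_real, ones) + bce(d_fake, zeros)

    # Generator objective: push D(G(z)) toward 1, i.e. minimize -log D(G(z)),
    # the practical (non-saturating) way to "maximize D(G(z))".
    g_loss = bce(d_fake, ones)
    print(d_loss.item(), g_loss.item())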

    Code implementation:

    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    # File : test_gan.py
    # Author : none
    # Date : 14.04.2022
    # Last Modified Date: 15.04.2022
    # Last Modified By : none
    """A generative adversarial network (GAN) trained on MNIST."""
    import torch
    import torchvision
    import torch.nn as nn
    import numpy as np

    image_size = [1, 28, 28]
    latent_dim = 96
    batch_size = 64
    use_gpu = torch.cuda.is_available()


    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()
            self.model = nn.Sequential(
                nn.Linear(latent_dim, 128),
                torch.nn.BatchNorm1d(128),
                torch.nn.GELU(),
                nn.Linear(128, 256),
                torch.nn.BatchNorm1d(256),
                torch.nn.GELU(),
                nn.Linear(256, 512),
                torch.nn.BatchNorm1d(512),
                torch.nn.GELU(),
                nn.Linear(512, 1024),
                torch.nn.BatchNorm1d(1024),
                torch.nn.GELU(),
                nn.Linear(1024, np.prod(image_size, dtype=np.int32)),
                # nn.Tanh(),
                nn.Sigmoid(),
            )

        def forward(self, z):
            # shape of z: [batchsize, latent_dim]
            output = self.model(z)
            image = output.reshape(z.shape[0], *image_size)
            return image


    class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()
            self.model = nn.Sequential(
                nn.Linear(np.prod(image_size, dtype=np.int32), 512),
                torch.nn.GELU(),
                nn.Linear(512, 256),
                torch.nn.GELU(),
                nn.Linear(256, 128),
                torch.nn.GELU(),
                nn.Linear(128, 64),
                torch.nn.GELU(),
                nn.Linear(64, 32),
                torch.nn.GELU(),
                nn.Linear(32, 1),
                nn.Sigmoid(),
            )

        def forward(self, image):
            # shape of image: [batchsize, 1, 28, 28]
            prob = self.model(image.reshape(image.shape[0], -1))
            return prob


    # Training
    dataset = torchvision.datasets.MNIST(r"D:\1APythonSpace\Use_model\gan\data\mnist", train=True, download=True,
                                         transform=torchvision.transforms.Compose(
                                             [
                                                 torchvision.transforms.Resize(28),
                                                 torchvision.transforms.ToTensor(),
                                                 # torchvision.transforms.Normalize([0.5], [0.5]),
                                             ]
                                         )
                                         )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    generator = Generator()
    discriminator = Discriminator()

    g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)

    loss_fn = nn.BCELoss()
    labels_one = torch.ones(batch_size, 1)
    labels_zero = torch.zeros(batch_size, 1)

    if use_gpu:
        print("use gpu for training")
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        loss_fn = loss_fn.cuda()
        labels_one = labels_one.to("cuda")
        labels_zero = labels_zero.to("cuda")

    num_epoch = 200
    for epoch in range(num_epoch):
        for i, mini_batch in enumerate(dataloader):
            gt_images, _ = mini_batch
            z = torch.randn(batch_size, latent_dim)
            if use_gpu:
                gt_images = gt_images.to("cuda")
                z = z.to("cuda")

            pred_images = generator(z)

            # Generator update: make D classify the generated images as real,
            # plus a small pixel-wise reconstruction term.
            g_optimizer.zero_grad()
            recons_loss = torch.abs(pred_images - gt_images).mean()
            g_loss = recons_loss * 0.05 + loss_fn(discriminator(pred_images), labels_one)
            g_loss.backward()
            g_optimizer.step()

            # Discriminator update: real images -> 1, generated images -> 0.
            d_optimizer.zero_grad()
            real_loss = loss_fn(discriminator(gt_images), labels_one)
            fake_loss = loss_fn(discriminator(pred_images.detach()), labels_zero)
            d_loss = (real_loss + fake_loss)
            # Watch real_loss and fake_loss: when they decrease together, reach their
            # minimum together, and end up roughly equal, D has stabilized.
            d_loss.backward()
            d_optimizer.step()

            if i % 50 == 0:
                print(f"step:{len(dataloader)*epoch+i}, recons_loss:{recons_loss.item()}, g_loss:{g_loss.item()}, d_loss:{d_loss.item()}, real_loss:{real_loss.item()}, fake_loss:{fake_loss.item()}")
            if i % 400 == 0:
                image = pred_images[:16].data
                torchvision.utils.save_image(image, f"image_{len(dataloader)*epoch+i}.png", nrow=4)
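
    After training, new digits can be sampled by drawing random latent vectors and passing them through the trained generator. A minimal sketch, reusing generator, latent_dim and use_gpu from the script above (the output file name samples.png is just an example):

    # Sample a 4x4 grid of digits from the trained generator.
    generator.eval()
    with torch.no_grad():
        z = torch.randn(16, latent_dim)
        if use_gpu:
            z = z.to("cuda")
        samples = generator(z)  # shape [16, 1, 28, 28], values in (0, 1) due to the final Sigmoid
    torchvision.utils.save_image(samples, "samples.png", nrow=4)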

  • Original article: https://blog.csdn.net/qq_53536373/article/details/140259840