• 对抗网络(GAN)手写数字生成



    活动地址:CSDN21天学习挑战赛

     


    目录

    1.跑通代码

     2.代码分析

    2.1

    2.2

    2.3

    2.4

    2.5



    (2条消息) tensorflow零基础入门学习_重邮研究森的博客-CSDN博客_tensorflow 学习icon-default.png?t=M666https://blog.csdn.net/m0_60524373/article/details/124143223https://blog.csdn.net/m0_60524373/article/details/124143223​>- 本文为[365天深度学习训练营](https://mp.weixin.qq.com/s/k-vYaC8l7uxX51WoypLkTw) 中的学习记录博客
    >- 参考文章地址: (1条消息) 深度学习100例-生成对抗网络(GAN)手写数字生成 | 第18天_K同学啊的博客-CSDN博客icon-default.png?t=M666https://mtyjkh.blog.csdn.net/article/details/118995896


    本文开发环境:tensorflowgpu2.5,经过验证,2.4也可以运行


    1.跑通代码

    我这个人对于任何代码,我都会先去跑通之和才会去观看内容,哈哈哈,所以第一步我们先不管37=21,直接把博主的代码复制黏贴一份运行结果。(PS:做了一些修改,因为原文是jupyter,而我在pycharm)

    1. import tensorflow as tf
    2. gpus = tf.config.list_physical_devices("GPU")
    3. if gpus:
    4. tf.config.experimental.set_memory_growth(gpus[0], True) # 设置GPU显存用量按需使用
    5. tf.config.set_visible_devices([gpus[0]], "GPU")
    6. # 打印显卡信息,确认GPU可用
    7. print(gpus)
    8. from tensorflow.keras import layers, datasets, Sequential, Model, optimizers
    9. from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2D
    10. import matplotlib.pyplot as plt
    11. import numpy as np
    12. import sys,os,pathlib
    13. img_shape = (28, 28, 1)
    14. latent_dim = 200
    15. def build_generator():
    16. # ======================================= #
    17. # 生成器,输入一串随机数字生成图片
    18. # ======================================= #
    19. model = Sequential([
    20. layers.Dense(256, input_dim=latent_dim),
    21. layers.LeakyReLU(alpha=0.2), # 高级一点的激活函数
    22. layers.BatchNormalization(momentum=0.8), # BN 归一化
    23. layers.Dense(512),
    24. layers.LeakyReLU(alpha=0.2),
    25. layers.BatchNormalization(momentum=0.8),
    26. layers.Dense(1024),
    27. layers.LeakyReLU(alpha=0.2),
    28. layers.BatchNormalization(momentum=0.8),
    29. layers.Dense(np.prod(img_shape), activation='tanh'),
    30. layers.Reshape(img_shape)
    31. ])
    32. noise = layers.Input(shape=(latent_dim,))
    33. img = model(noise)
    34. return Model(noise, img)
    35. def build_discriminator():
    36. # ===================================== #
    37. # 鉴别器,对输入的图片进行判别真假
    38. # ===================================== #
    39. model = Sequential([
    40. layers.Flatten(input_shape=img_shape),
    41. layers.Dense(512),
    42. layers.LeakyReLU(alpha=0.2),
    43. layers.Dense(256),
    44. layers.LeakyReLU(alpha=0.2),
    45. layers.Dense(1, activation='sigmoid')
    46. ])
    47. img = layers.Input(shape=img_shape)
    48. validity = model(img)
    49. return Model(img, validity)
    50. # 创建判别器
    51. discriminator = build_discriminator()
    52. # 定义优化器
    53. optimizer = tf.keras.optimizers.Adam(1e-4)
    54. discriminator.compile(loss='binary_crossentropy',
    55. optimizer=optimizer,
    56. metrics=['accuracy'])
    57. # 创建生成器
    58. generator = build_generator()
    59. gan_input = layers.Input(shape=(latent_dim,))
    60. img = generator(gan_input)
    61. # 在训练generate的时候不训练discriminator
    62. discriminator.trainable = False
    63. # 对生成的假图片进行预测
    64. validity = discriminator(img)
    65. combined = Model(gan_input, validity)
    66. combined.compile(loss='binary_crossentropy', optimizer=optimizer)
    67. def sample_images(epoch):
    68. """
    69. 保存样例图片
    70. """
    71. row, col = 4, 4
    72. noise = np.random.normal(0, 1, (row*col, latent_dim))
    73. gen_imgs = generator.predict(noise)
    74. fig, axs = plt.subplots(row, col)
    75. cnt = 0
    76. for i in range(row):
    77. for j in range(col):
    78. axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
    79. axs[i,j].axis('off')
    80. cnt += 1
    81. fig.savefig("images/%05d.png" % epoch)
    82. # fig.savefig(" E:/2021_Project_YanYiXia/AI/21/对抗网络(GAN)手写数字生成/images/%05d.png" % epoch)
    83. plt.close()
    84. def train(epochs, batch_size=128, sample_interval=50):
    85. # 加载数据
    86. (train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
    87. # 将图片标准化到 [-1, 1] 区间内
    88. train_images = (train_images - 127.5) / 127.5
    89. # 数据
    90. train_images = np.expand_dims(train_images, axis=3)
    91. # 创建标签
    92. true = np.ones((batch_size, 1))
    93. fake = np.zeros((batch_size, 1))
    94. # 进行循环训练
    95. for epoch in range(epochs):
    96. # 随机选择 batch_size 张图片
    97. idx = np.random.randint(0, train_images.shape[0], batch_size)
    98. imgs = train_images[idx]
    99. # 生成噪音
    100. noise = np.random.normal(0, 1, (batch_size, latent_dim))
    101. # 生成器通过噪音生成图片,gen_imgs的shape为:(128, 28, 28, 1)
    102. gen_imgs = generator.predict(noise)
    103. # 训练鉴别器
    104. d_loss_true = discriminator.train_on_batch(imgs, true)
    105. d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
    106. # 返回loss值
    107. d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)
    108. # 训练生成器
    109. noise = np.random.normal(0, 1, (batch_size, latent_dim))
    110. g_loss = combined.train_on_batch(noise, true)
    111. print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
    112. # 保存样例图片
    113. if epoch % sample_interval == 0:
    114. sample_images(epoch)
    115. #train(epochs=30000, batch_size=256, sample_interval=200)
    116. import imageio
    117. def compose_gif():
    118. # 图片地址
    119. data_dir = "E:/2021_Project_YanYiXia/AI/21/对抗网络(GAN)手写数字生成/images"
    120. data_dir = pathlib.Path(data_dir)
    121. paths = list(data_dir.glob('*'))
    122. gif_images = []
    123. for path in paths:
    124. print(path)
    125. gif_images.append(imageio.imread(path))
    126. imageio.mimsave("test.gif", gif_images, fps=2)
    127. compose_gif()

    点击pycharm运行即可得到结果,此图为对抗网络生成的手写数字

     2.代码分析

    神经网络的整个过程我分为如下六部分,而我们也会对这六部分进行逐部分分析。那么这6部分分别是:但是:这里是对抗网络,和传统方法有区别,因此六步法不适用了,我们将重新分析。

    五步法:

    1->import

    2->设置生成器和判别器

    3->创建生成器和判别器

    4->训练模型

    5->验证

    2.1

    导入:这里很容易理解,也就是导入本次实验内容所需要的各种库。在本案例中主要包括以下部分:

     蓝框1:

    设置电脑gpu工作,如果你的电脑没有gpu就不设置,或者你的gpu显存不够,训练时出问题了,那么就设在为cpu模式

    蓝框2:

    导入各种库

    对于这里的话我们可以直接复制黏贴,当需要一些其他函数时,只需要添加对应的库文件即可。

    2.2

    这里是设置生成器和判别器。其中对于生成器和判别器的详细解释可以参考下面这个链接。

    四天搞懂生成对抗网络(一)——通俗理解经典GAN - 知乎 (zhihu.com)icon-default.png?t=M666https://zhuanlan.zhihu.com/p/307527293总之,我们需要清楚的一点是,对抗网络中:

    生成器:根据随机数生成一些“以假乱真”的数据集

    判别器:对生成器“以假乱真”的数据集和真实的数据集进行判别

    两者在训练过程中都会不断进行优化,生成器会不断生成更多“更真”的数据,判别器会“检测”的更仔细。

    下面进行详细代码解释:

    蓝框 1:

    这里设置我们的图片格式和输入的维度

    蓝框2:

    这里引入了alpha激活函数,批标准化,prod函数。

    先设置网络层,然后把噪音当作输入层,img当作输出层。

     这部分为判别器。

    输入是之前噪音生成的img,输出是真或者假

    很有趣的一点,生成器和判别器的网络模型基本上是对折的!

    2.3

    在生成器和判别器都定义好之后,我们可以创建它们

     蓝框1:

    设置优化器,关于优化器的参数设置可以参考文章开头我之前写的一篇基础文章

    蓝框2:

    创建生成器,生成器输入为噪音维度的数,输出是图片数据

    蓝框3:

    在训练生成器的时候不训练判别器

    2.4

    在完成基础准备工作之后,就可以开始训练了

     重点!!!

    现在我们来对如何对生成器和判别器训练进行代码解读!!!

    蓝框1:

    这里就是我们之前文章调用minist数据集制作datset的方法,包括加载数据,数据处理,归一化,修改维度。同时这里的区别是:在标签方面,1为真,0为假,都是根据batch_size来生产的一个列表。

    蓝框2:

    返回一个随机整型数,范围从低0(包括)到高 train_images.shape[0](不包括).另外输出随机数的尺寸为batch_size

    总结:这里就是随机获取官方数据集中任意一组图片

    蓝框3:

    从正态(高斯)分布中抽取随机样本。其中样本尺寸为(batch_size, latent_dim)

    根据噪音的随机情况可以生成随机的一个图片数据

    蓝框4:

    discriminator.train_on_batch(imgs, true)的意思是,输入为img,输出为true,返回结果为loss

    把真实数据和假数据的loss计算出来为总loss

    蓝框5

    从正态(高斯)分布中抽取随机样本。其中样本尺寸为(batch_size, latent_dim)

    根据噪音的随机情况然后利用 combined.train_on_batch(noise, true)的意思是,输入为noise,输出为true,返回结果为loss

    蓝框6

    打印每轮的损失函数信息

    蓝框7

    执行训练

     2.5

    训练结束后,我们可以观察训练的结果

    上面是生成tif动图代码

     上面是保存样例图片代码

     

     

  • 相关阅读:
    解决安装apex报错:No module named ‘packaging‘
    洛谷 P3372 【模板】线段树 1
    vue + koa + 阿里云部署 + 宝塔:宝塔前后端部署
    模拟shell小程序
    0829|C++day7 auto、lambda、C++数据类型转换、C++标准模板库(STL)、list、文件操作
    2-(脏读,不可重复读,幻读 ,mysql5.7以后默认隔离级别)、( 什么是qps,tps,并发量,pv,uv)、(什么是接口幂等性问题,如何解决?)
    Shiro讲解(基于Springboot搭建)
    入门cv必读的10篇baseline论文
    Kubernetes CKA 模拟题解析【2022最新版】(连载002)
    SpringBoot使用spring.config.import多种方式导入配置文件
  • 原文地址:https://blog.csdn.net/m0_60524373/article/details/126341959