• 好好学习第四天:生成对抗网络(GAN)手写数字生成


    >- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/k-vYaC8l7uxX51WoypLkTw) 中的学习记录博客**
    >- **🍦 参考文章地址: [🔗深度学习100例-生成对抗网络(GAN)手写数字生成 | 第18天](https://mtyjkh.blog.csdn.net/article/details/118995896)**
    >- **🍖 作者:[K同学啊](https://mp.weixin.qq.com/s/k-vYaC8l7uxX51WoypLkTw)**

    活动地址:CSDN21天学习挑战赛

    一、GAN的原理

    参考论文:Generative Adversarial Nets

    主要内容:提出了一个通过对抗性过程来估计生成模型的新框架,其中我们同时训练两个模型:一个是捕获数据分布的生成模型G,另一个是估计样本来自训练数据而不是G的概率的判别模型D。G的训练过程是最大化D犯错的概率。这个框架对应于一个极大极小的双人博弈。在任意函数G和D的空间中,存在一个唯一解,G恢复训练数据的分布,且D处处等于12。在G和D由多层感知器定义的情况下,整个系统可以用反向传播进行训练。在样本的训练和生成过程中,不需要任何马尔可夫链或展开的近似推理网络

    • 生成模型G :模型学习的是联合概率分布P(X,Y),任务是得到属性为X且类别为Y时的联合概率
    • 判别模型D: 模型学习的是条件概率分布P(Y|X),任务是从属性X(特征)预测Y(类别)
    • minimax:为了使己方达到最优解,所以把目标设为让对方的最大收益最小化
    • 在提出的对抗性网络框架中,生成模型是针对对手的:一个判别模型,它学习确定样本是来自模型分布还是数据分布。生成模型可以被认为类似于一个伪造团队,试图生产假币并使用它而不被发现,而判别模型类似于警察,试图发现假币。这场比赛的竞争促使两支队伍改进他们的方法,直到仿冒品与真品难以区分。

     

    二、实验过程

    用生成器生成手写数字图像,用鉴别器鉴别图像的真假。两者相互对抗学习,在对抗学习过程中不断完善自己,直至生成器可以生成以假乱真的图片。

    1.导入库

    1. from tensorflow.keras import layers, datasets, Sequential, Model, optimizers
    2. from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2D
    3. import matplotlib.pyplot as plt
    4. import numpy as np
    5. import sys,os,pathlib

     2.定义训练参数

    1. img_shape = (28, 28, 1)
    2. latent_dim = 200

    3.构建生成器

    生成器接收随机数并返回生成图像。

    将生成的数字图像与实际数据集中的数字图像一起送到鉴别器。

    1. def build_generator():
    2. # ======================================= #
    3. # 生成器,输入一串随机数字生成图片
    4. # ======================================= #
    5. model = Sequential([
    6. layers.Dense(256, input_dim=latent_dim),
    7. layers.LeakyReLU(alpha=0.2), # 高级一点的激活函数
    8. layers.BatchNormalization(momentum=0.8), # BN 归一化
    9. layers.Dense(512),
    10. layers.LeakyReLU(alpha=0.2),
    11. layers.BatchNormalization(momentum=0.8),
    12. layers.Dense(1024),
    13. layers.LeakyReLU(alpha=0.2),
    14. layers.BatchNormalization(momentum=0.8),
    15. layers.Dense(np.prod(img_shape), activation='tanh'),
    16. layers.Reshape(img_shape)
    17. ])
    18. noise = layers.Input(shape=(latent_dim,))
    19. img = model(noise)
    20. return Model(noise, img)

    4.构建鉴别器

    鉴别器接收真实和假图像并返回概率,0到1之间的数字,1表示真,0表示假。

    1. def build_discriminator():
    2. # ===================================== #
    3. # 鉴别器,对输入的图片进行判别真假
    4. # ===================================== #
    5. model = Sequential([
    6. layers.Flatten(input_shape=img_shape),
    7. layers.Dense(512),
    8. layers.LeakyReLU(alpha=0.2),
    9. layers.Dense(256),
    10. layers.LeakyReLU(alpha=0.2),
    11. layers.Dense(1, activation='sigmoid')
    12. ])
    13. img = layers.Input(shape=img_shape)
    14. validity = model(img)
    15. return Model(img, validity)

    5.不断训练

    import tensorflow as tf
    1. # 创建判别器
    2. discriminator = build_discriminator()
    3. # 定义优化器
    4. optimizer = tf.keras.optimizers.Adam(1e-4)
    5. discriminator.compile(loss='binary_crossentropy',
    6. optimizer=optimizer,
    7. metrics=['accuracy'])
    8. # 创建生成器
    9. generator = build_generator()
    10. gan_input = layers.Input(shape=(latent_dim,))
    11. img = generator(gan_input)
    12. # 在训练generate的时候不训练discriminator
    13. discriminator.trainable = False
    14. # 对生成的假图片进行预测
    15. validity = discriminator(img)
    16. combined = Model(gan_input, validity)
    17. combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    6.保存样例图片

    1. def sample_images(epoch):
    2. """
    3. 保存样例图片
    4. """
    5. row, col = 4, 4
    6. noise = np.random.normal(0, 1, (row*col, latent_dim))
    7. gen_imgs = generator.predict(noise)
    8. fig, axs = plt.subplots(row, col)
    9. cnt = 0
    10. for i in range(row):
    11. for j in range(col):
    12. axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
    13. axs[i,j].axis('off')
    14. cnt += 1
    15. fig.savefig("images/%05d.png" % epoch)
    16. plt.close()

     7.训练模型

    1. def train(epochs, batch_size=128, sample_interval=50):
    2. # 加载数据
    3. (train_images,_), (_,_) = tf.keras.datasets.mnist.load_data()
    4. # 将图片标准化到 [-1, 1] 区间内
    5. train_images = (train_images - 127.5) / 127.5
    6. # 数据
    7. train_images = np.expand_dims(train_images, axis=3)
    8. # 创建标签
    9. true = np.ones((batch_size, 1))
    10. fake = np.zeros((batch_size, 1))
    11. # 进行循环训练
    12. for epoch in range(epochs):
    13. # 随机选择 batch_size 张图片
    14. idx = np.random.randint(0, train_images.shape[0], batch_size)
    15. imgs = train_images[idx]
    16. # 生成噪音
    17. noise = np.random.normal(0, 1, (batch_size, latent_dim))
    18. # 生成器通过噪音生成图片,gen_imgs的shape为:(128, 28, 28, 1)
    19. gen_imgs = generator.predict(noise)
    20. # 训练鉴别器
    21. d_loss_true = discriminator.train_on_batch(imgs, true)
    22. d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
    23. # 返回loss值
    24. d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)
    25. # 训练生成器
    26. noise = np.random.normal(0, 1, (batch_size, latent_dim))
    27. g_loss = combined.train_on_batch(noise, true)
    28. print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
    29. # 保存样例图片
    30. if epoch % sample_interval == 0:
    31. sample_images(epoch)
    train(epochs=1200, batch_size=256, sample_interval=200)
    

     

     8.生成动图

    1. import imageio
    2. def compose_gif():
    3. # 图片地址
    4. data_dir = "images"
    5. data_dir = pathlib.Path(data_dir)
    6. paths = list(data_dir.glob('*'))
    7. gif_images = []
    8. for path in paths:
    9. print(path)
    10. gif_images.append(imageio.imread(path))
    11. imageio.mimsave("test.gif",gif_images,fps=2)
    12. compose_gif()

    gif

    9.生成图片展示

    1. import os
    2. import numpy as np
    3. import matplotlib.pyplot as plt
    4. import random
    5. directory = "images/"
    6. images = random.choices(os.listdir(directory), k=6)
    7. fig = plt.figure(figsize=(10, 10))
    8. columns = 3
    9. rows = 2
    10. for x, i in enumerate(images):
    11. path = os.path.join(directory,i)
    12. img = plt.imread(path)
    13. fig.add_subplot(rows, columns, x+1)
    14. plt.imshow(img)
    15. plt.show()

     

  • 相关阅读:
    IDL学习——哨兵2 L1C数据辐射定标
    研究报告:周界警戒AI算法+视频智能分析在安全生产场景中的应用
    【C++】vector的模拟实现【完整版】
    Java项目中排查JVM问题的思路
    好心情:6种会加重抑郁症的食物,你却每天都在吃
    Vue.js 的事件循环(Event Loop)机制
    猿创征文|【深度学习前沿应用】文本审核
    目标检测:Anchor-free算法模型
    npm ERR! exited with error code: 128
    已经有 MESI 协议,为什么还需要 volatile 关键字?
  • 原文地址:https://blog.csdn.net/liuyingshudian/article/details/126351891