>- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/k-vYaC8l7uxX51WoypLkTw) 中的学习记录博客**
>- **🍦 参考文章地址: [🔗深度学习100例-生成对抗网络(GAN)手写数字生成 | 第18天](https://mtyjkh.blog.csdn.net/article/details/118995896)**
>- **🍖 作者:[K同学啊](https://mp.weixin.qq.com/s/k-vYaC8l7uxX51WoypLkTw)**
活动地址:CSDN21天学习挑战赛
参考论文:Generative Adversarial Nets
主要内容:提出了一个通过对抗性过程来估计生成模型的新框架,其中我们同时训练两个模型:一个是捕获数据分布的生成模型G,另一个是估计样本来自训练数据而不是G的概率的判别模型D。G的训练过程是最大化D犯错的概率。这个框架对应于一个极大极小的双人博弈。在任意函数G和D的空间中,存在一个唯一解,G恢复训练数据的分布,且D处处等于 $\frac{1}{2}$。在G和D由多层感知器定义的情况下,整个系统可以用反向传播进行训练。在样本的训练和生成过程中,不需要任何马尔可夫链或展开的近似推理网络。
用生成器生成手写数字图像,用鉴别器鉴别图像的真假。两者相互对抗学习,在对抗学习过程中不断完善自己,直至生成器可以生成以假乱真的图片。
- from tensorflow.keras import layers, datasets, Sequential, Model, optimizers
- from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2D
-
- import matplotlib.pyplot as plt
- import numpy as np
- import sys,os,pathlib
-
- img_shape = (28, 28, 1)  # shape of one image: 28x28 with a single trailing channel axis
- latent_dim = 200  # length of the random noise vector fed to the generator
生成器接收随机数并返回生成图像。
将生成的数字图像与实际数据集中的数字图像一起送到鉴别器。
def build_generator():
    """Build the generator network.

    Maps a noise vector of length ``latent_dim`` to an image of shape
    ``img_shape``. The final ``tanh`` keeps every output value in [-1, 1].

    Returns:
        A Keras ``Model`` that takes a noise tensor and returns an image.
    """
    # First stage: the only Dense layer that is told the noise dimension.
    stack = [
        layers.Dense(256, input_dim=latent_dim),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
    ]

    # Two wider stages with the same Dense -> LeakyReLU -> BatchNorm pattern.
    for units in (512, 1024):
        stack.append(layers.Dense(units))
        stack.append(layers.LeakyReLU(alpha=0.2))
        stack.append(layers.BatchNormalization(momentum=0.8))

    # One output unit per pixel, then fold the flat vector back into an image.
    stack.append(layers.Dense(np.prod(img_shape), activation='tanh'))
    stack.append(layers.Reshape(img_shape))
    net = Sequential(stack)

    z = layers.Input(shape=(latent_dim,))
    return Model(z, net(z))
鉴别器接收真实和假图像并返回概率,0到1之间的数字,1表示真,0表示假。
def build_discriminator():
    """Build the discriminator network.

    Flattens an image of shape ``img_shape`` and passes it through two
    LeakyReLU-activated Dense layers to a single sigmoid unit, so the output
    is a value in (0, 1).

    Returns:
        A Keras ``Model`` that takes an image and returns its validity score.
    """
    stack = [layers.Flatten(input_shape=img_shape)]

    # Hidden stages narrow from 512 to 256 units.
    for units in (512, 256):
        stack.append(layers.Dense(units))
        stack.append(layers.LeakyReLU(alpha=0.2))

    # Single sigmoid output: the real-vs-fake score.
    stack.append(layers.Dense(1, activation='sigmoid'))
    net = Sequential(stack)

    image_in = layers.Input(shape=img_shape)
    return Model(image_in, net(image_in))
import tensorflow as tf
# Build the discriminator and compile it for standalone training.
discriminator = build_discriminator()
# Optimizer used by the discriminator only. An optimizer instance tracks the
# variables of the model it trains, so it must not be reused for `combined`,
# which updates a different (generator-only) set of variables.
optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])

# Build the generator and feed it a latent-noise input.
generator = build_generator()
gan_input = layers.Input(shape=(latent_dim,))
img = generator(gan_input)

# Freeze the discriminator inside the stacked model so that training
# `combined` only updates the generator.
discriminator.trainable = False

# Score the generated images with the (frozen) discriminator.
validity = discriminator(img)
combined = Model(gan_input, validity)
# Separate Adam instance with the same learning rate as the discriminator's.
combined_optimizer = tf.keras.optimizers.Adam(1e-4)
combined.compile(loss='binary_crossentropy', optimizer=combined_optimizer)
def sample_images(epoch):
    """Save a 4x4 grid of freshly generated images as a PNG.

    Draws 16 noise vectors, runs them through the module-level ``generator``
    and writes the resulting grid to ``images/<epoch, zero-padded to 5>.png``.
    The ``images`` directory is created if it does not exist yet.

    Args:
        epoch: Integer used to build the output file name.
    """
    row, col = 4, 4
    noise = np.random.normal(0, 1, (row * col, latent_dim))
    gen_imgs = generator.predict(noise)

    fig, axs = plt.subplots(row, col)
    # axs.flat walks the grid row by row, matching the order of gen_imgs.
    for ax, sample in zip(axs.flat, gen_imgs):
        ax.imshow(sample[:, :, 0], cmap='gray')
        ax.axis('off')

    # savefig does not create missing directories; without this the first
    # call fails with FileNotFoundError on a fresh checkout.
    os.makedirs("images", exist_ok=True)
    fig.savefig("images/%05d.png" % epoch)
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)
def train(epochs, batch_size=128, sample_interval=50):
    """Alternate discriminator and generator updates on MNIST.

    Each iteration trains the discriminator on one random batch of real
    images and one batch of generated images, then trains the generator
    (through ``combined``) to make the discriminator answer "real".

    Args:
        epochs: Number of training iterations; each one processes one batch.
        batch_size: Number of real and of generated images per iteration.
        sample_interval: A sample grid is saved whenever the iteration index
            is a multiple of this value.
    """
    (mnist_x, _), (_, _) = tf.keras.datasets.mnist.load_data()

    # Rescale pixels to [-1, 1] and append the channel axis in one go.
    mnist_x = np.expand_dims((mnist_x - 127.5) / 127.5, axis=3)

    # Target labels: 1 for real images, 0 for generated ones.
    label_shape = (batch_size, 1)
    real_labels = np.ones(label_shape)
    fake_labels = np.zeros(label_shape)

    for step in range(epochs):
        # Random batch of real images.
        picks = np.random.randint(0, mnist_x.shape[0], batch_size)
        real_batch = mnist_x[picks]

        # Batch of generated images from fresh noise.
        z = np.random.normal(0, 1, (batch_size, latent_dim))
        fake_batch = generator.predict(z)

        # Discriminator update; average the [loss, accuracy] of both halves.
        stats_real = discriminator.train_on_batch(real_batch, real_labels)
        stats_fake = discriminator.train_on_batch(fake_batch, fake_labels)
        d_loss = 0.5 * np.add(stats_real, stats_fake)

        # Generator update: new noise, labelled as real.
        z = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = combined.train_on_batch(z, real_labels)

        report = (step, d_loss[0], 100*d_loss[1], g_loss)
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % report)

        # Only save a sample grid on multiples of sample_interval.
        if step % sample_interval != 0:
            continue
        sample_images(step)
train(epochs=1200, batch_size=256, sample_interval=200)  # 1200 iterations of 256-image batches; save a sample grid every 200 iterations
- import imageio
-
def compose_gif():
    """Combine the PNG files in ``images/`` into ``test.gif`` at 2 fps.

    Frames are added in sorted file-name order. The sample grids are saved
    with zero-padded numeric names, so sorted order follows the training
    iteration they were saved at.
    """
    data_dir = pathlib.Path("images")
    # Path.glob yields entries in arbitrary order and '*' would also match
    # non-image entries (e.g. hidden files); keep PNGs only and sort them.
    paths = sorted(data_dir.glob('*.png'))

    gif_images = []
    for path in paths:
        print(path)
        gif_images.append(imageio.imread(path))
    imageio.mimsave("test.gif", gif_images, fps=2)


compose_gif()
gif
- import os
- import numpy as np
- import matplotlib.pyplot as plt
- import random
-
# Show up to six distinct, randomly chosen files from the images directory
# in a 2x3 grid.
directory = "images/"
available = os.listdir(directory)
# random.sample draws without replacement, so no image is shown twice;
# capping k avoids an error when fewer than six files exist.
images = random.sample(available, k=min(6, len(available)))

fig = plt.figure(figsize=(10, 10))
columns = 3
rows = 2

for position, filename in enumerate(images):
    path = os.path.join(directory, filename)
    img = plt.imread(path)
    # Subplot indices are 1-based.
    fig.add_subplot(rows, columns, position + 1)
    plt.imshow(img)

plt.show()