• 【深度学习实践(八)】生成对抗网络(GAN)之手写数字生成


    活动地址:CSDN21天学习挑战赛

      👉引言💎

    学习的最大理由是想摆脱平庸,早一天就多一份人生的精彩;迟一天就多一天平庸的困扰。
    热爱写作,愿意让自己成为更好的人…


    在这里插入图片描述

    铭记于心
    🎉✨🎉我唯一知道的,便是我一无所知🎉✨🎉

    【深度学习实践(八)】对抗生成网络(GAN)之手写数字生成

    一、🌹对抗生成网络

    1 定义与背景 :

    生成对抗网络GAN是由蒙特利尔大学Ian Goodfellow在2014年提出的机器学习架构,GAN的核心本质是通过对抗训练将随机噪声的分布拉近到真实的数据分布

    2 基本结构:
    • GAN本身是一个不断博弈,识别真假的过程,下面通过手写数字生成案例 窥探GAN对抗生成网络的原理及操作流程:

    在这里插入图片描述

    • 定义一个模型来作为生成器(图三中蓝色部分Generator),能够输入一个向量,输出手写数字大小的像素图像(生成噪声

    • 定义一个分类器来作为判别器(图三中红色部分Discriminator)用来判别图片是真的还是假的(或者说是来自数据集中的还是生成器中生成的),输入为手写图片,输出为判别图片的标签

    并且,既然是神经网络,那么模型就可以根据 外界反馈 自行调整参数,也就是会根据 标签匹配结果进行相应的学习与调整, 训练完成后 可以达到 ** 生成以假乱真的 手写数字图片效果**

    二、🌹模型训练

    💎1 设置GPU
    • GPU能够为大量数据的运算提供算力支持
    import tensorflow as tf
    gpus = tf.config.list_physical_devices("GPU")
    
    if gpus:
        gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
        tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
        tf.config.set_visible_devices([gpu0],"GPU")
        
    warnings.filterwarnings("ignore")             
    plt.rcParams['font.sans-serif'] = ['SimHei']  
    plt.rcParams['axes.unicode_minus'] = False    
    
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    💎2 构建GAN对抗网络生成器
    def build_generator():
        # ======================================= #
        #     生成器,输入一串随机数字生成图片
        # ======================================= #
        model = Sequential([
            layers.Dense(256, input_dim=latent_dim),
            layers.LeakyReLU(alpha=0.2),               # 高级一点的激活函数
            layers.BatchNormalization(momentum=0.8),   # BN 归一化
            
            layers.Dense(512),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(momentum=0.8),
            
            layers.Dense(1024),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(momentum=0.8),
            
            layers.Dense(np.prod(img_shape), activation='tanh'),
            layers.Reshape(img_shape)
        ])
    
        noise = layers.Input(shape=(latent_dim,))
        img = model(noise)
    
        return Model(noise, img)
    
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    💎3 构造鉴别器
    def build_discriminator():
        model = Sequential([
            layers.Flatten(input_shape=img_shape),
            layers.Dense(512),
            layers.LeakyReLU(alpha=0.2),
            layers.Dense(256),
            layers.LeakyReLU(alpha=0.2),
            layers.Dense(1, activation='sigmoid')
        ])
    
        img = layers.Input(shape=img_shape)
        validity = model(img)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 最后传入img以及model参数构造Model对象
      return Model(img, validity)
    • 鉴别器训练原理:通过对输入的图片进行鉴别,从而达到提升的效果

    • 生成器训练原理:通过鉴别器对其生成的图片进行鉴别,来实现提升


    💎4 构造生成器
    # 创建判别器
    dis = build_discriminator()
    
    # 定义优化器
    optimizer = tf.keras.optimizers.Adam(1e-4)
    dis.compile(loss='binary_crossentropy',
                          optimizer=optimizer,
                          metrics=['accuracy'])
                          
    # 创建生成器                       
    generator = build_generator()
    gan_input = layers.Input(shape=(latent_dim,))
    img = generator(gan_input) 
    
    #训练generate时候停止训练判别器
    dis.trainable = False  
    
    # 测试:对生成的假图片进行预测 
    validity = discriminator(img)
    combined = Model(gan_input, validity)
    combined.compile(loss='binary_crossentropy', optimizer=optimizer)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    💎5 训练模型
    • train_on_batch详解:

      keras在compile完模型后需要训练,除了常用的model.fit()与model.fit_generator外

      还有model.train_on_bantch作用:对一批样品进行单梯度更新,即对一个epoch中的一个样本进行一次训练

    • 使用train_on_batch优点:

      更精细自定义训练过程,更精准的收集 loss 和 metrics

      分布训练模型-GAN生成对抗神经网络的实现

      多GPU训练保存模型更加方便

      def train(epochs, batch_size=128, sample_interval=50):
      
      • 1
    • 加载数据

      (train_images,_), (_,_) = tf.keras.datasets.mnist.load_data()
      
      • 1
    • 将图片标准化到 [-1, 1] 区间内

      train_images = (train_images - 127.5) / 127.5
      
      • 1
    • 数据

      train_images = np.expand_dims(train_images, axis=3)
      
      • 1
    • 创建标签

      true = np.ones((batch_size, 1))
      fake = np.zeros((batch_size, 1))
      
      • 1
      • 2
    • 开始训练

      for epoch in range(epochs): 
      
          
              idx = np.random.randint(0, train_images.shape[0], batch_size)
              imgs = train_images[idx]      
              
          
              noise = np.random.normal(0, 1, (batch_size, latent_dim))
       
              gen_imgs = generator.predict(noise)
              
      
              d_loss_true = discriminator.train_on_batch(imgs, true)
              d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
              
              d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)
      
              noise = np.random.normal(0, 1, (batch_size, latent_dim))
              g_loss = combined.train_on_batch(noise, true)
              
              print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
      
              # 保存样例图片
              if epoch % sample_interval == 0:
                  sample_images(epoch)
      
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
      • 15
      • 16
      • 17
      • 18
      • 19
      • 20
      • 21
      • 22
      • 23
      • 24
      • 25
      • 26

    在这里插入图片描述

    • 动图展示

      def compose_gif():
          # 图片地址
          data_dir = "F:/jupyter notebook/DL-100-days/code/images"
          data_dir = pathlib.Path(data_dir)
          paths    = list(data_dir.glob('*'))
          
          gif_images = []
          for path in paths:
              print(path)
              gif_images.append(imageio.imread(path))
          imageio.mimsave("test.gif",gif_images,fps=2)
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11

    🌹写在最后💖
    路漫漫其修远兮,吾将上下而求索!伙伴们,再见!🌹🌹🌹在这里插入图片描述

  • 相关阅读:
    NumPy 均匀分布模拟及 Seaborn 可视化教程
    Android一键锁屏,去除锁屏密码
    Kubernetes云原生实战03 搭建高可用负载均衡器(Keepalived 和 HAproxy)
    ARM32开发——第一盏灯
    serialVersionUID的重要性,及Idea自动生成 serialVersionUID的设置
    什么是RPC?RPC框架dubbo的核心流程
    Electron实战之进程间通信
    MyBatis和MyBatis-Plus的差别和优缺点
    qt udp tcp代替RPC(二) 图片AI服务器
    Java-微服务-谷粒商城-1-环境搭建&项目初始化
  • 原文地址:https://blog.csdn.net/runofsun/article/details/126446166