• 自动编码器(AE)生成Mnist手写数字集,基于tensorflow和keras实现


    本文实现的是自动编码器(Auto Encoder,AE),而不是变分自动编码器(Variational Auto Encoder,VAE)。因此代码只能实现通过Mnist数据集自编码出一个相似的新的手写数字集,而不是实现通过输入随机高斯分布的隐含变量生成全新的手写数字。

    1、code

    # @Time    : 2022/8/22 21:21
    # @Author  : CSDN User: ctrl A_ctrl C_ctrl V
    # @Function: valid AE(Auto Encoder) using mnist dataset
    
    
    import tensorflow as tf
    import tensorflow.keras as keras
    import matplotlib.pyplot as plt
    import numpy as np
    import random
    
    # hyper parameter
    epochs = 10
    batchSize = 512
    
    # load dataset
    (x_train, _), (x_valid, _) = keras.datasets.mnist.load_data()
    assert x_train.shape == (60000, 28, 28)
    assert x_valid.shape == (10000, 28, 28)
    x_train = x_train.reshape(x_train.shape[0], -1)   # (60000,784)
    x_valid = x_valid.reshape(x_valid.shape[0], -1)
    
    # normalization
    x_train = tf.cast(x_train, tf.float32) / 255
    x_valid = tf.cast(x_valid, tf.float32) / 255
    
    # encoder and decoder layer
    inputSize = 784
    hiddenSize = 32
    outputSize = 784
    inputDim = keras.layers.Input(shape=(inputSize,))
    encodeLayer = keras.layers.Dense(hiddenSize, activation='relu')(inputDim)
    decoderLayer = keras.layers.Dense(outputSize, activation='sigmoid')(encodeLayer)
    
    # bulid model
    model = keras.Model(inputs=inputDim, outputs=decoderLayer)
    print(model.summary())
    
    # get encoder and decoder from model
    encoder = keras.Model(inputs=inputDim, outputs=encodeLayer)
    decoderInput = keras.layers.Input(shape=(hiddenSize,))
    decoderOutput = model.layers[-1](decoderInput)
    decoder = keras.Model(inputs=decoderInput, outputs=decoderOutput)
    
    # train
    # VAE是没有label的,以输入图像本身作为label,因此这里的 x=y=x_train
    model.compile(optimizer='adam', loss='mse')
    model.fit(x=x_train, y=x_train, epochs=epochs, batch_size=batchSize, shuffle=True, validation_data=(x_valid, x_valid))
    
    # valid
    # display ten images randomly for visualization
    encoder_valid = encoder.predict(x_valid)
    decoder_valid = decoder.predict(encoder_valid)
    x_valid = x_valid.numpy()
    visualNum = 10
    startNum = random.randint(0, 10000 - visualNum)
    plt.figure(figsize=(20, 4))
    for i in range(1, visualNum + 1):
        plt.subplot(2, visualNum, i)
        plt.imshow(x_valid[startNum + i].reshape(28, 28))
        plt.subplot(2, visualNum, visualNum + i)
        plt.imshow(decoder_valid[startNum + i].reshape(28, 28))
    plt.show()
    
    # test
    # test with random matrix,display ten images randomly for visualization
    test_tensor = np.random.rand(visualNum, hiddenSize)
    test_output = decoder.predict(test_tensor)
    plt.figure(figsize=(20, 4))
    for i in range(1, visualNum + 1):
        plt.subplot(2, visualNum, i)
        plt.imshow(test_output[i - 1].reshape(28, 28))
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73

    2、生成结果

    (1) epoch=1

    验证集(生成的图像比较模糊,但已经有基本轮廓):
    在这里插入图片描述
    用随机生成的矩阵进行测试(毫无规律):

    在这里插入图片描述

    (2) epoch=5

    验证集(生成的图像有一些模糊,但轮廓非常清晰):

    在这里插入图片描述
    用随机生成的矩阵进行测试(依然毫无规律):

    在这里插入图片描述

    (3) epoch=10

    验证集(生成的图像已经非常清晰):
    在这里插入图片描述
    用随机生成的矩阵进行测试(依然毫无规律):

    在这里插入图片描述

    (4)总结

    正如前面所言,AE只能复现图像,不能生成图像。所以随着epoch的增加,复现的图像越来越清晰,但无法通过随机矩阵生成我们想要的图像。要想实现真正的图像生成需要用VAE和GAN。

  • 相关阅读:
    【数据库】E-R图相关知识、绘制方法及工具推荐
    SOP作业指导书系统如何帮助厂家实现数字化转型
    一文1500字手把手教你Jmeter如何压测数据库【保姆级教程】
    白皮书 |得帆云低代码aPaaS X OA全新解决方案,解锁数字化协作新境界
    00_socket_demo
    排序算法之计数排序
    Java反射(Reflex)机制
    Unity实现设计模式——责任链模式
    MySQL - 联表查询从表即使有索引依然 ALL 的一个原因
    docker 部署lnmp
  • 原文地址:https://blog.csdn.net/qq_43799400/article/details/126480373