• MNIST dataset


    Training the model

    import tensorflow as tf
    
    import keras
    from keras.models import Sequential
    from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
    # from keras.optimizers import SGD
    from tensorflow.keras.optimizers import Adam, Nadam, SGD
    
    
    from PIL import Image
    import numpy as np
    import matplotlib.pyplot as plt
    
    print('tf',tf.__version__)
    print('keras',keras.__version__)
    
    
    # Batch size: one gradient update is performed after every 64 samples
    batch_size = 64
    # Number of classes in the training data
    num_classes = 10
    # Number of training epochs
    epochs = 5
    
    
    
    # Load the MNIST arrays from a local mnist.npz file
    f = np.load("mnist.npz")
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
    f.close()
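    ## Alternatively, if mnist.npz is not present locally, tf.keras can download
    ## the same arrays:
    # (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()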
    print(x_train.shape,"  ",y_train.shape)
    print(x_test.shape,"  ",y_test.shape)
    
    # im=plt.imshow(x_train[0],cmap="gray")
    # plt.show()
    
    ## Flatten each 28x28 image into a 784-dimensional vector (784 = 28*28)
    x_train = x_train.reshape(60000, 784).astype('float32')
    x_test = x_test.reshape(10000, 784).astype('float32')
    
    ## Normalize: scale pixel values into the range 0-1
    x_train /= 255
    x_test /= 255
    
    # print(x_train[0])
    
    # Convert the labels to one-hot encoding
    y_train = tf.keras.utils.to_categorical(y_train, num_classes)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes)
    # print(y_train[0]) ## e.g. [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
    
    print(x_train.shape,"  ",y_train.shape)
    print(x_test.shape,"  ",y_test.shape)
    
    # Build the model
    model = Sequential()
    model.add(Dense(512, activation='relu', input_shape=(784,)))
    model.add(Dense(256, activation='relu'))
    model.add(Dense(num_classes, activation='softmax'))
    model.summary()
    
    # [Compile the model] Cross-entropy loss, the Adadelta optimizer, and accuracy as the evaluation metric
    model.compile(loss=tf.keras.losses.categorical_crossentropy, optimizer=tf.keras.optimizers.Adadelta(), metrics=['accuracy'])
    # validation_data is the validation set
    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data=(x_test, y_test))
    
    # Evaluate the model; verbose=0 suppresses log output
    score = model.evaluate(x_test, y_test, verbose=0)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1]) # accuracy
    
    
    model.save('mnist_model_weights.h5') # Save the trained model (despite the name, this saves the full model, not just the weights)
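
    As a side note, model.fit() returns a History object; a minimal sketch (reusing the variables from the script above) that captures it and plots the learning curves with the matplotlib import already present:

    # Capture the History object by writing the fit call above as:
    history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
                        verbose=1, validation_data=(x_test, y_test))

    # Plot training vs. validation accuracy per epoch
    plt.plot(history.history['accuracy'], label='train accuracy')
    plt.plot(history.history['val_accuracy'], label='val accuracy')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.legend()
    plt.show()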
    
    

    Training results

    Epoch 4/5
    938/938 [==============================] - 7s 7ms/step - loss: 1.5926 - accuracy: 0.7292 - val_loss: 1.4802 - val_accuracy: 0.7653
    Epoch 5/5
    938/938 [==============================] - 6s 6ms/step - loss: 1.4047 - accuracy: 0.7686 - val_loss: 1.2988 - val_accuracy: 0.7918
    Test loss: 1.2988097667694092
    Test accuracy: 0.7918000221252441
    
    Process finished with exit code 0
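
    The final accuracy (around 79%) is low for MNIST; this is most likely because tf.keras's Adadelta defaults to a very small learning rate (0.001), so it converges slowly. A hedged variant of the compile step (using the Adam import already present in the training script), which on this architecture typically reaches roughly 98% validation accuracy within the same 5 epochs:

    # Variant (assumes the same model and data as above): Adam instead of Adadelta
    model.compile(loss=tf.keras.losses.categorical_crossentropy,
                  optimizer=Adam(),
                  metrics=['accuracy'])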
    
    

    Testing the model

    
    import tensorflow as tf
    
    from PIL import Image
    import numpy as np
    from keras.models import load_model
    
    # Rebuild the model by loading the saved file
    model = load_model('mnist_model_weights.h5') # load the trained model
    # model.summary()
    
    def read_image(img_name):
        # Resize the input to 28x28 (the training image size) and convert to grayscale.
        # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the equivalent filter.
        im = Image.open(img_name).resize((28, 28), Image.LANCZOS).convert('L')
        data = np.array(im)
        return data
    
    images=[]
    images.append(read_image("test.png"))
    # print(images)
    
    X = np.array(images)
    print(X.shape)
    X = X.reshape(1, 784).astype('float32')
    print(X.shape)
    X /= 255
    # print(X[0:1])
    # Result for the first image; for multiple images, drop the trailing [0]
    # to get the results for all of them
    result = model.predict(X[0:1])[0]
    # result has length 10; the index of the largest probability is the prediction
    num = int(np.argmax(result))
    
    print("Predicted digit:", num)
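
    One caveat: MNIST digits are white strokes on a black background, so a typical dark digit on a light background will usually be misclassified. A minimal preprocessing sketch (reusing read_image and model from the script above) that inverts such images first:

    img = read_image("test.png").astype('float32')
    if img.mean() > 127:      # mostly light pixels: likely a dark-on-light image
        img = 255.0 - img     # invert to match MNIST's white-on-black convention
    X = (img / 255.0).reshape(1, 784)
    print("Predicted digit:", int(np.argmax(model.predict(X)[0])))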
    

    Converting the dataset to images

    
    # coding: utf-8
    import os
    import tensorflow as tf
    import input_data  # legacy helper script from the TensorFlow 1.x MNIST tutorials
    from PIL import Image
    
    '''
    Function: extract the images from the MNIST dataset as .bmp files
    Parameters:
        mnist_dir   directory where the MNIST dataset is stored
        save_dir    directory where the extraction results are saved
    '''
    
    def extract_mnist(mnist_dir, save_dir):
        rows = 28
        cols = 28
    
        # Load the MNIST dataset
        # (one_hot=True would enable one-hot encoding; it is disabled here)
        mnist = input_data.read_data_sets(mnist_dir, one_hot=False)
        # Number of training images
        shape = mnist.train.images.shape
        images_train_count = shape[0]
        pixels_count_per_image = shape[1]
        # Number of training labels (equal to the number of training images).
        # With one-hot encoding disabled, labels look like [7 3 4 ... 5 6 8]
        labels = mnist.train.labels
        labels_train_count = labels.shape[0]
    
        if (images_train_count == labels_train_count):
            print("Training set contains %d images and %d labels" % (images_train_count, labels_train_count))
            print("Each image contains %d pixels" % (pixels_count_per_image))
            print("Data type:", mnist.train.images.dtype)
    
            # MNIST pixel values lie in [0, 1]; map them into [0, 255]
            # (here every nonzero pixel is simply set to 255)
            for current_image_id in range(images_train_count):
                for i in range(pixels_count_per_image):
                    if mnist.train.images[current_image_id][i] != 0:
                        mnist.train.images[current_image_id][i] = 255
    
                if ((current_image_id + 1) % 50) == 0:
                    print("Converted %d of %d images" %
                          (current_image_id + 1, images_train_count))
    
            # Create the save directories for the train images, one subdirectory per label
            for i in range(10):
                dir = "%s/%s" % (save_dir, i)
                print(dir)
                if not os.path.exists(dir):
                    os.mkdir(dir)
    
            # indices = [0, 0, ..., 0] records how many images have been saved for each label
            indices = [0 for x in range(0, 10)]
            for i in range(images_train_count):
                new_image = Image.new("L", (cols, rows))
                # Fill new_image pixel by pixel. putpixel takes (x, y) =
                # (column, row), so the coordinates must be (c, r);
                # (r, c) would transpose the image.
                for r in range(rows):
                    for c in range(cols):
                        new_image.putpixel(
                            (c, r), int(mnist.train.images[i][c + r * cols]))
    
                # Label of the i-th training image
                label = labels[i]
                image_save_path = "%s/%s/%s.bmp" % (save_dir, label,
                                                    indices[label])
                indices[label] += 1
                new_image.save(image_save_path)
    
                # Print the save progress
                if ((i + 1) % 50) == 0:
                    print("Save progress: %d of %d images saved" % (i + 1, images_train_count))
        else:
            print("Image count and label count do not match!")
    
    
    if __name__ == '__main__':
        mnist_dir = "Mnist_Data"
        save_dir = "Mnist_Data_TrainImages"
        extract_mnist(mnist_dir, save_dir)
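
    The script above depends on input_data, the legacy helper module from the TensorFlow 1.x tutorials. A minimal alternative sketch that performs the same per-label BMP export using tf.keras.datasets.mnist, with no extra module (the function name extract_mnist_keras is a hypothetical one of my own):

    import os
    import tensorflow as tf
    from PIL import Image

    def extract_mnist_keras(save_dir):
        # uint8 arrays in [0, 255]; the training split has shape (60000, 28, 28)
        (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
        counters = [0] * 10  # number of images saved so far, per label
        for img, label in zip(x_train, y_train):
            label_dir = os.path.join(save_dir, str(label))
            os.makedirs(label_dir, exist_ok=True)
            Image.fromarray(img, mode='L').save(
                os.path.join(label_dir, "%d.bmp" % counters[label]))
            counters[label] += 1

    # extract_mnist_keras("Mnist_Data_TrainImages")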
    
    
    

    Creating an MNIST-format dataset from images

    import os
    from PIL import Image
    from array import array
    from random import shuffle
    
    # # File layout:
    # ├── training-images
    # │   ├── 0 (images of class 0)
    # │   ├── 1 (images of class 1)
    # │   ├── 2 (images of class 2)
    # │   ├── 3 (images of class 3)
    # │   └── 4 (images of class 4)
    # ├── test-images
    # │   ├── 0 (images of class 0)
    # │   ├── 1 (images of class 1)
    # │   ├── 2 (images of class 2)
    # │   ├── 3 (images of class 3)
    # │   └── 4 (images of class 4)
    # └── mnist数据集制作.py (this script)
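    # NOTE (assumption): the .png files are expected to be 28x28 single-channel
    # grayscale images; the pixel[x, y] access below returns one integer per
    # pixel only for mode 'L' images.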
    
    # Load from and save to
    Names = [['./training-images', 'train'], ['./test-images', 'test']]
    
    for name in Names:
    
        data_image = array('B')
        data_label = array('B')
    
        print(os.listdir(name[0]))
        FileList = []
        for dirname in os.listdir(name[0])[0:]:  # use [1:] instead to skip .DS_Store on macOS
            # print(dirname)
            path = os.path.join(name[0], dirname)
            # print(path)
            for filename in os.listdir(path):
                # print(filename)
                if filename.endswith(".png"):
                    FileList.append(os.path.join(name[0] + '/', dirname + '/', filename))
            print(FileList)
        shuffle(FileList)  # Useful for later carving out a validation set
    
        for filename in FileList:
    
            label = int(filename.split('/')[2])
            print(filename)
            Im = Image.open(filename)
            # print(Im)
    
            pixel = Im.load()
    
            width, height = Im.size
    
            # Append the pixels in row-major order (row by row), as the IDX format expects
            for y in range(0, height):
                for x in range(0, width):
                    data_image.append(pixel[x, y])
    
            data_label.append(label)  # labels start (one unsigned byte each)
    
        hexval = "{0:#0{1}x}".format(len(FileList), 6)  # number of files in HEX (supports up to 65535 files)
    
        # header for label array
    
        header = array('B')
        header.extend([0, 0, 8, 1, 0, 0])
        header.append(int('0x' + hexval[2:][:2], 16))
        header.append(int('0x' + hexval[2:][2:], 16))
    
        data_label = header + data_label
    
        # additional header for images array
    
        if max([width, height]) <= 256:
            # IDX3 stores the row count (height) before the column count (width)
            header.extend([0, 0, 0, height, 0, 0, 0, width])
        else:
            raise ValueError('Image exceeds maximum size: 256x256 pixels')
    
        header[3] = 3  # Changing MSB for image data (0x00000803)
    
        data_image = header + data_image
    
        output_file = open(name[1] + '-images-idx3-ubyte', 'wb')
        data_image.tofile(output_file)
        output_file.close()
    
        output_file = open(name[1] + '-labels-idx1-ubyte', 'wb')
        data_label.tofile(output_file)
        output_file.close()
    
    # Running the script produces four files: test-images-idx3-ubyte, test-labels-idx1-ubyte,
    # train-images-idx3-ubyte and train-labels-idx1-ubyte.
    # Compress each of them on the command line, e.g. gzip -c train-labels-idx1-ubyte > train-labels-idx1-ubyte.gz,
    # to obtain the final MNIST-format dataset.
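
    A quick way to sanity-check the generated files is to parse the IDX headers back; a minimal sketch (assuming the train-images-idx3-ubyte filename produced above):

    import struct
    import numpy as np

    def load_idx_images(path):
        # IDX3 layout: magic 0x00000803, image count, rows, cols (big-endian
        # 32-bit integers), followed by the raw pixel bytes
        with open(path, 'rb') as f:
            magic, count, rows, cols = struct.unpack('>IIII', f.read(16))
            assert magic == 0x803, "not an idx3 image file"
            return np.frombuffer(f.read(), dtype=np.uint8).reshape(count, rows, cols)

    images = load_idx_images('train-images-idx3-ubyte')
    print(images.shape)  # e.g. (N, 28, 28)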
    
    
  • Original article: https://blog.csdn.net/vive921/article/details/133151075