• TensorFlow: Handwritten Digit Recognition with a Multilayer Perceptron


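    GitHub: https://github.com/fz861062923/TensorFlow
    Note: the dataset is downloaded from a server outside China, and a mysterious force may answer your download with an HTTP 403.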

    Data preparation

    import tensorflow as tf
    # TF 1.x only: the tutorials package is not shipped with TensorFlow 2.x
    import tensorflow.examples.tutorials.mnist.input_data as input_data
    
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    
    C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\h5py\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
      from ._conv import register_converters as _register_converters
    
    
    WARNING:tensorflow:From :4: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
    WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please write your own downloading logic.
    WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use tf.data to implement this functionality.
    Extracting MNIST_data/train-images-idx3-ubyte.gz
    WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use tf.data to implement this functionality.
    Extracting MNIST_data/train-labels-idx1-ubyte.gz
    WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use tf.one_hot on tensors.
    WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py:252: _internal_retry..wrap..wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use urllib or similar directly.
    Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
    Extracting MNIST_data/t10k-images-idx3-ubyte.gz
    Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
    Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
    WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
    
    print('train images     :', mnist.train.images.shape,
          'labels:'           , mnist.train.labels.shape)
    print('validation images:', mnist.validation.images.shape,
          ' labels:'          , mnist.validation.labels.shape)
    print('test images      :', mnist.test.images.shape,
          'labels:'           , mnist.test.labels.shape)
    
    train images     : (55000, 784) labels: (55000, 10)
    validation images: (5000, 784)  labels: (5000, 10)
    test images      : (10000, 784) labels: (10000, 10)
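With `one_hot=True`, each label is stored as a length-10 vector holding a single 1 at the digit's index, which is why every label array has shape `(N, 10)`; each image is a flattened 28 × 28 = 784 pixel vector. A minimal pure-Python sketch of the encoding and its inverse (the helper names `one_hot` and `decode` are hypothetical, not part of the TensorFlow API):

```python
def one_hot(label, num_classes=10):
    """Return a list with 1.0 at position `label` and 0.0 elsewhere."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def decode(vec):
    """Inverse of one_hot: index of the largest entry (what np.argmax returns)."""
    return max(range(len(vec)), key=lambda i: vec[i])


encoded = one_hot(7)
print(encoded)          # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
print(decode(encoded))  # 7
print(28 * 28)          # 784 pixels per flattened image
```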
    

    Building the model

    def layer(output_dim, input_dim, inputs, activation=None):  # activation defaults to None (linear output)
        W = tf.Variable(tf.random_normal([input_dim, output_dim]))  # weights W, initialized from a standard normal distribution
        b = tf.Variable(tf.random_normal([1, output_dim]))
        XWb = tf.matmul(inputs, W) + b
        if activation is None:
            outputs = XWb
        else:
            outputs = activation(XWb)
        return outputs
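The `layer` helper computes `activation(inputs · W + b)`, i.e. a fully connected (dense) layer. Here is the same arithmetic for a single input row in plain Python, so each multiply-add is visible; the function names and the tiny 3-input, 2-output weights are made up purely for illustration:

```python
def relu(z):
    """Element-wise max(0, z), the role tf.nn.relu plays above."""
    return [max(0.0, v) for v in z]


def dense(x, W, b, activation=None):
    """x: input_dim values; W: input_dim rows of output_dim; b: output_dim biases."""
    # xwb[j] = sum_i x[i] * W[i][j] + b[j], the same as tf.matmul(inputs, W) + b
    xwb = [sum(x[i] * W[i][j] for i in range(len(x))) + b[j]
           for j in range(len(b))]
    return xwb if activation is None else activation(xwb)


x = [1.0, 2.0, 3.0]
W = [[0.5, -1.0],
     [0.0, 1.0],
     [1.0, -1.0]]
b = [0.5, 0.0]
print(dense(x, W, b))        # [4.0, -2.0]
print(dense(x, W, b, relu))  # [4.0, 0.0]
```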
    
    Create the input layer x
    x = tf.placeholder("float", [None, 784])
    
    Create hidden layer h1
    h1=layer(output_dim=1000,input_dim=784,
             inputs=x ,activation=tf.nn.relu)  
    
    
    Create hidden layer h2
    h2=layer(output_dim=1000,input_dim=1000,
             inputs=h1 ,activation=tf.nn.relu)  
    
    Create the output layer
    y_predict=layer(output_dim=10,input_dim=1000,
                    inputs=h2,activation=None)
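Each `layer` call creates an `input_dim × output_dim` weight matrix plus `output_dim` biases, so the size of this 784-1000-1000-10 network can be checked with a few lines of arithmetic:

```python
# (input_dim, output_dim) for h1, h2 and the output layer, as defined above
layers = [(784, 1000), (1000, 1000), (1000, 10)]


def dense_params(input_dim, output_dim):
    # weight matrix W plus bias vector b
    return input_dim * output_dim + output_dim


counts = [dense_params(i, o) for i, o in layers]
print(counts)       # [785000, 1001000, 10010]
print(sum(counts))  # 1796010 trainable parameters in total
```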
    

    Defining the training procedure

    Create a placeholder for the true labels of the training data
    y_label = tf.placeholder("float", [None, 10])  # None leaves the batch dimension open, so any number of samples can be fed
    
    Define the loss function
    # Cross-entropy generally works well as the training loss for deep classification models
    loss_function = tf.reduce_mean(
                       tf.nn.softmax_cross_entropy_with_logits_v2
                           (logits=y_predict , 
                            labels=y_label))
    
    Choose the optimizer
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001) \
                        .minimize(loss_function)
    # Measure the error with loss_function, then update the weights and biases to minimize it
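For intuition, here is the standard Adam update applied to one scalar parameter minimizing f(x) = (x - 3)^2. It keeps exponential moving averages of the gradient and of its square, corrects their bias toward zero in the first steps, and scales each step by their ratio. beta1 = 0.9, beta2 = 0.999 and eps = 1e-8 are the usual defaults; TensorFlow 1.x folds the bias correction into the learning rate and therefore places eps slightly differently, so treat this as an illustration, not a copy of `tf.train.AdamOptimizer`:

```python
import math


def adam_minimize(grad, x0, lr=0.05, steps=2000,
                  beta1=0.9, beta2=0.999, eps=1e-8):
    """Run `steps` Adam updates on a scalar x, starting from x0."""
    x, m, v = x0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad(x)
        m = beta1 * m + (1 - beta1) * g      # moving average of gradients
        v = beta2 * v + (1 - beta2) * g * g  # moving average of squared gradients
        m_hat = m / (1 - beta1 ** t)         # bias correction
        v_hat = v / (1 - beta2 ** t)
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return x


def quad_grad(x):
    # derivative of f(x) = (x - 3) ** 2
    return 2 * (x - 3)


print(adam_minimize(quad_grad, 0.0, steps=1))  # first step moves by about lr (~0.05)
print(adam_minimize(quad_grad, 0.0))           # ends close to the minimum at 3
```

Because the first step divides the gradient by roughly its own magnitude, its size is about `lr` regardless of how steep f is; this is one reason Adam is not very sensitive to the scale of the loss.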
    

    Defining the model's accuracy

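    Check whether each sample is predicted correctly
    correct_prediction = tf.equal(tf.argmax(y_label  , 1),
                                  tf.argmax(y_predict, 1))  # argmax turns each one-hot label (and each output row) into the index of its largest entry, so the two can be compared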
    
    Average the per-sample results to get the accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
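Together, `tf.equal(tf.argmax(...), tf.argmax(...))`, `tf.cast` and `tf.reduce_mean` turn per-sample comparisons into a fraction of correct answers. The same three steps on a made-up batch of four samples, in plain Python:

```python
def argmax(row):
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(label_rows, logit_rows):
    # tf.equal(tf.argmax(y_label, 1), tf.argmax(y_predict, 1))
    correct = [argmax(y) == argmax(p) for y, p in zip(label_rows, logit_rows)]
    # tf.reduce_mean(tf.cast(correct_prediction, "float"))
    return sum(float(c) for c in correct) / len(correct)


labels = [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0]]  # one-hot labels
logits = [[0.1, 2.0, -1.0], [3.0, 0.5, 0.2],
          [0.0, 4.0, 1.0], [-2.0, 1.5, 0.3]]           # raw model outputs
print(accuracy(labels, logits))  # 0.75 (the third sample is wrong)
```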
    

    Training

    from time import time

    trainEpochs = 15  # run 15 training epochs
    batchSize = 100   # 100 samples per batch
    totalBatchs = int(mnist.train.num_examples / batchSize)  # batches per epoch: 55000 / 100 = 550
    epoch_list = []
    accuracy_list = []
    loss_list = []
    startTime = time()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    for epoch in range(trainEpochs):
        # each of the 15 epochs runs 550 training batches
        for i in range(totalBatchs):
            batch_x, batch_y = mnist.train.next_batch(batchSize)  # read the next batch of training data
            sess.run(optimizer, feed_dict={x: batch_x,
                                           y_label: batch_y})

        # compute loss and accuracy on the validation data
        loss, acc = sess.run([loss_function, accuracy],
                             feed_dict={x: mnist.validation.images,          # validation features
                                        y_label: mnist.validation.labels})   # validation labels

        epoch_list.append(epoch)
        loss_list.append(loss)
        accuracy_list.append(acc)

        print("Train Epoch:", '%02d' % (epoch+1),
              "Loss=", "{:.9f}".format(loss), " Accuracy=", acc)

    duration = time() - startTime
    print("Train Finished takes:", duration)
    
    Train Epoch: 01 Loss= 133.117172241  Accuracy= 0.9194
    Train Epoch: 02 Loss= 88.949943542  Accuracy= 0.9392
    Train Epoch: 03 Loss= 80.701606750  Accuracy= 0.9446
    Train Epoch: 04 Loss= 72.045913696  Accuracy= 0.9506
    Train Epoch: 05 Loss= 71.911483765  Accuracy= 0.9502
    Train Epoch: 06 Loss= 63.642936707  Accuracy= 0.9558
    Train Epoch: 07 Loss= 67.192626953  Accuracy= 0.9494
    Train Epoch: 08 Loss= 55.959281921  Accuracy= 0.9618
    Train Epoch: 09 Loss= 58.867351532  Accuracy= 0.9592
    Train Epoch: 10 Loss= 61.904548645  Accuracy= 0.9612
    Train Epoch: 11 Loss= 58.283069611  Accuracy= 0.9608
    Train Epoch: 12 Loss= 54.332244873  Accuracy= 0.9646
    Train Epoch: 13 Loss= 58.152175903  Accuracy= 0.9624
    Train Epoch: 14 Loss= 51.552104950  Accuracy= 0.9688
    Train Epoch: 15 Loss= 52.803482056  Accuracy= 0.9678
    Train Finished takes: 545.0556836128235
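Each epoch therefore runs `totalBatchs = 55000 / 100 = 550` optimizer steps, one per mini-batch returned by `next_batch`, which by default reshuffles the training set once per epoch. The generator below is a hypothetical plain-Python stand-in that illustrates shuffled mini-batching; it is not the code of the TensorFlow helper:

```python
import random


def iterate_minibatches(n_examples, batch_size, seed=None):
    """Yield lists of example indices that cover one shuffled epoch."""
    indices = list(range(n_examples))
    random.Random(seed).shuffle(indices)
    # Drop a final incomplete batch, matching int(n_examples / batch_size)
    for start in range(0, n_examples - batch_size + 1, batch_size):
        yield indices[start:start + batch_size]


batches = list(iterate_minibatches(55000, 100, seed=0))
print(len(batches))     # 550 batches per epoch
print(len(batches[0]))  # 100 samples per batch
```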
    
    Plot the loss per epoch
    %matplotlib inline
    import matplotlib.pyplot as plt
    fig = plt.gcf()            # get the current figure
    fig.set_size_inches(4, 2)  # set the figure size in inches
    plt.plot(epoch_list, loss_list, label = 'loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['loss'], loc='upper left')
    

    [Figure: loss vs. epoch]

    Plot the accuracy per epoch
    plt.plot(epoch_list, accuracy_list,label="accuracy" )
    fig = plt.gcf()
    fig.set_size_inches(4,2)
    plt.ylim(0.8,1)
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend()
    plt.show()
    

    [Figure: accuracy vs. epoch]

    Evaluating accuracy on the test set

    print("Accuracy:", sess.run(accuracy,
                               feed_dict={x: mnist.test.images, 
                                          y_label: mnist.test.labels}))
    
    Accuracy: 0.9643
    

    Making predictions

    prediction_result=sess.run(tf.argmax(y_predict,1),
                               feed_dict={x: mnist.test.images })
    
    prediction_result[:10]
    
    array([7, 2, 1, 0, 4, 1, 4, 9, 6, 9], dtype=int64)
    
    import matplotlib.pyplot as plt
    import numpy as np
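    def plot_images_labels_prediction(images, labels,
                                      prediction, idx, num=10):
        """Show up to 25 images starting at idx, titled with the true label and, if given, the prediction."""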
        fig = plt.gcf()
        fig.set_size_inches(12, 14)
        if num>25: num=25 
        for i in range(0, num):
            ax=plt.subplot(5,5, 1+i)
            
            ax.imshow(np.reshape(images[idx],(28, 28)), 
                      cmap='binary')
                
            title= "label=" +str(np.argmax(labels[idx]))
            if len(prediction)>0:
                title+=",predict="+str(prediction[idx]) 
                
            ax.set_title(title,fontsize=10) 
            ax.set_xticks([]);ax.set_yticks([])        
            idx+=1 
        plt.show()
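`np.reshape(images[idx], (28, 28))` undoes the flattening from data preparation: it splits the 784-element vector into 28 rows of 28 pixels in row-major order, so flat index `r * 28 + c` becomes pixel `(r, c)`. The equivalent in plain Python (the helper name `reshape_rows` is hypothetical):

```python
def reshape_rows(flat, width=28):
    """Split a flat pixel list into rows of `width` pixels (row-major order)."""
    if len(flat) % width != 0:
        raise ValueError("length must be a multiple of width")
    return [flat[r * width:(r + 1) * width] for r in range(len(flat) // width)]


flat = list(range(784))  # stand-in for one flattened 784-pixel MNIST image
grid = reshape_rows(flat)
print(len(grid), len(grid[0]))  # 28 28
print(grid[1][0])               # 28: pixel (row 1, col 0) is flat index 28
```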
    
    plot_images_labels_prediction(mnist.test.images,
                                  mnist.test.labels,
                                  prediction_result,0)
    

    [Figure: test images 0-9 with their labels and predictions]

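    # raw output-layer values (logits) for every test image; despite the variable name, these are not one-hot vectors
    y_predict_Onehot=sess.run(y_predict,
                              feed_dict={x: mnist.test.images })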
    
    y_predict_Onehot[8]
    
    array([-6185.544  , -5329.589  ,  1897.1707 , -3942.7764 ,   347.9809 ,
            5513.258  ,  6735.7153 , -5088.5273 ,   649.2062 ,    69.50408],
          dtype=float32)
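These ten numbers are logits for test image 8, whose label is 5; argmax picks index 6. Passing them through softmax shows how extreme they are: the two largest values differ by more than 1,200, so in double precision the probabilities come out as exactly 1.0 for class 6 and 0.0 everywhere else. A minimal stable-softmax sketch in plain Python, fed with the values printed above:

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# y_predict_Onehot[8] as printed above
row8 = [-6185.544, -5329.589, 1897.1707, -3942.7764, 347.9809,
        5513.258, 6735.7153, -5088.5273, 649.2062, 69.50408]
probs = softmax(row8)
print(max(range(10), key=lambda i: probs[i]))  # 6
print(probs[6], probs[5])                      # 1.0 0.0
```

In other words, the network is fully "certain" about a wrong answer here, which is exactly the case that the cross-entropy loss penalizes most heavily.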
    

    Finding misclassified samples

    for i in range(400):
        if prediction_result[i]!=np.argmax(mnist.test.labels[i]):
            print("i="+str(i)+"   label=",np.argmax(mnist.test.labels[i]),
                  "predict=",prediction_result[i])
    
    i=8   label= 5 predict= 6
    i=18   label= 3 predict= 8
    i=149   label= 2 predict= 4
    i=151   label= 9 predict= 8
    i=233   label= 8 predict= 7
    i=241   label= 9 predict= 8
    i=245   label= 3 predict= 5
    i=247   label= 4 predict= 2
    i=259   label= 6 predict= 0
    i=320   label= 9 predict= 1
    i=340   label= 5 predict= 3
    i=381   label= 3 predict= 7
    i=386   label= 6 predict= 5
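The loop can also be written as a list comprehension over `enumerate`. The sketch below uses the first ten values of `prediction_result` shown earlier, with labels given as plain digits (what `np.argmax` returns for the one-hot rows); according to the list above, index 8 is the only error among the first ten, so the other labels equal the predictions. The helper name `misclassified` is hypothetical:

```python
def misclassified(predictions, labels):
    """Return (index, label, prediction) for every wrong prediction."""
    return [(i, y, p) for i, (p, y) in enumerate(zip(predictions, labels))
            if p != y]


preds = [7, 2, 1, 0, 4, 1, 4, 9, 6, 9]  # prediction_result[:10]
truth = [7, 2, 1, 0, 4, 1, 4, 9, 5, 9]  # labels; only index 8 differs
print(misclassified(preds, truth))      # [(8, 5, 6)]
print((400 - 13) / 400)                 # 0.9675: 13 errors in the first 400 images
```

The 13 errors in the first 400 test images correspond to 96.75% correct on that slice, close to the overall test accuracy of 0.9643.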
    
    sess.close()
    
  • Original article: https://blog.csdn.net/weixin_41503009/article/details/86767531