• python神经网络实现手写数字识别实验


         手写数字识别实验是机器学习中最常见的一个示例,可以有很多种办法实现,最基础的其实就是利用knn算法,根据数字图片对应矩阵与经过训练的数字进行距离计算,最后这个距离最短,那么就认为它是哪个数字。

         这里直接通过神经网络的办法来进行手写数字识别实验。不借助其他框架,编写网络,然后进行测试。这个代码其实网上有很多,并不是原创。

         这里有必要说明一下手写数字的数据集,这里采用的是mnist_dataset/mnist_train.csv数据集,数据地址: https://www.kaggle.com/datasets/oddrationale/mnist-in-csv。下载之后是一个压缩包,里面包含mnist_train.csv,mnist_test.csv。

        我们可以看看mnist_train.csv的部分数据:

     

        上图中,①处表示 第一行内容 其实是标题,我们在数据处理的时候需要过滤这一行。② 表示的是label内容,也就是真实数字,它由0-9组成,也就是10个分类。③ 处表示的28 * 28矩阵,这个数字由784个数字组成。 

        实验过程,先使用mnist_train.csv数据训练网络,然后利用我们自己手写的数字进行测试。这里没有使用mnist_test.csv进行测试,主要是它本身就是人家进行测试的数据,我们这里自己测试。

        我自己准备的数字图片如下所示:

        这些图片都是根据这里测试数据mnist_train.csv数据格式的要求进行绘制的28*28像素的图片,这个图片很小,但是可以借助windows系统paint绘图工具,选择28*28像素画布,然后进行放大,最后可以在编辑区域画出这些数字。

         

         下面给出代码:

    1. import os
    2. import numpy as np
    3. import scipy.special
    4. import imageio
    5. image_path = 'number_images'
    6. # 加载图片
    7. def load_img_number(root_dir):
    8. files = os.listdir(root_dir)
    9. file_list = []
    10. for file in files:
    11. file_path = os.path.join(root_dir, file)
    12. file_list.append(file_path)
    13. return file_list
    14. class neuralnetwork:
    15. def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
    16. # 输入层
    17. self.inodes = inputnodes
    18. # 隐藏层
    19. self.hnodes = hiddennodes
    20. # 输出层
    21. self.onodes = outputnodes
    22. # 学习率
    23. self.lr = learningrate
    24. # 输入层-隐藏层权重
    25. self.wih = (np.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes)))
    26. # 隐藏层-输出层权重
    27. self.who = (np.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes)))
    28. # 激活函数
    29. self.activation_function = lambda x: scipy.special.expit(x)
    30. def train(self, inputs_list, targets_list):
    31. inputs = np.array(inputs_list, ndmin=2).T
    32. targets = np.array(targets_list, ndmin=2).T
    33. hidden_inputs = np.dot(self.wih, inputs)
    34. hidden_outputs = self.activation_function(hidden_inputs)
    35. final_inputs = np.dot(self.who, hidden_outputs)
    36. final_outputs = self.activation_function(final_inputs)
    37. output_errors = targets - final_outputs
    38. hidden_errors = np.dot(self.who.T, output_errors)
    39. self.who += self.lr * np.dot((output_errors * final_outputs * (1.0 - final_outputs)),
    40. np.transpose(hidden_outputs))
    41. self.wih += self.lr * np.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), np.transpose(inputs))
    42. def query(self, inputs_list):
    43. inputs = np.array(inputs_list, ndmin=2).T
    44. hidden_inputs = np.dot(self.wih, inputs)
    45. hidden_outputs = self.activation_function(hidden_inputs)
    46. final_inputs = np.dot(self.who, hidden_outputs)
    47. final_outputs = self.activation_function(final_inputs)
    48. return final_outputs
    49. input_nodes = 784
    50. hidden_nodes = 200
    51. output_nodes = 10
    52. learning_rate = 0.2
    53. # 构建模型
    54. model = neuralnetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)
    55. # 准备训练数据
    56. training_data_file = open('mnist/mnist_train.csv', 'r')
    57. training_data_list = training_data_file.readlines()
    58. # 去掉第一行标题
    59. training_data_list = training_data_list[1:]
    60. training_data_file.close()
    61. # 训练
    62. for record in training_data_list:
    63. all_values = record.split(',')
    64. inputs = (np.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01
    65. targets = np.zeros(output_nodes) + 0.01
    66. targets[int(all_values[0])] = 0.99
    67. model.train(inputs, targets)
    68. pass
    69. img_list = load_img_number(image_path)
    70. for i in range(len(img_list)):
    71. img_name = img_list[i]
    72. img_arr = imageio.v2.imread(img_name, mode='L')
    73. img_data = 255.0 - img_arr.reshape(784)
    74. inputs = (img_data / 255.0 * 0.99) + 0.01
    75. outputs = model.query(inputs)
    76. label = np.argmax(outputs)
    77. print(f'{img_name} 识别结果是 {label}')

        运行代码,打印结果:

     

        1、识别率很感人,其实很多都识别错误。 

        2、多次运行,结果也不一样。

        3、识别不正确的基本会认为6或者8。不知道怎么会有这种奇怪的结果。

       最后,给出本示例的代码和资源:https://gitee.com/buejee/aitutorial

  • 相关阅读:
    线性代数学习笔记8-4:正定矩阵、二次型的几何意义、配方法与消元法的联系、最小二乘法与半正定矩阵A^T A
    Python | GUI | tinker不完全总结
    数字化转型企业成功的关键,用数据创造价值
    简化Microsoft365审核
    Android 10.0 Launcher3定制化之folder文件夹文件居中显示的功能实现
    运行游戏“找不到XINPUTI_3.dll无法继续执行代码,总共有五种解决方案
    IBM车库创新:为科技创新头号工程打造共创引擎
    快速部署 MySQL InnoDB Cluster
    金仓数据库KingbaseES数据库参考手册(服务器配置参数13. 锁管理)
    如何使用 Pyinstaller 编译打包 Python 项目生成 exe 可执行文件(2023 年最新详细教程)
  • 原文地址:https://blog.csdn.net/feinifi/article/details/130854741