• tensorflow跑手写体实验


    目录

    1、环境条件

    2、代码实现

    3、总结


    1、环境条件

    1. pycharm编译器
    2. python3.0环境
    3. tensorflow2.0依赖
    4. matplotlib依赖(用于画图)

    2、代码实现

    1. import tensorflow as tf
    2. from tensorflow.keras.datasets import mnist
    3. from tensorflow.keras.preprocessing import image
    4. import numpy as np
    5. import matplotlib.pyplot as plt
    6. # 加载并预处理 MNIST 数据集
    7. (x_train, y_train), (x_test, y_test) = mnist.load_data()
    8. x_train, x_test = x_train / 255.0, x_test / 255.0
    9. print(x_train)
    10. print(x_test)
    11. # 构建 LeNet-5 模型
    12. model = tf.keras.models.Sequential([
    13. tf.keras.layers.Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)),
    14. tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    15. tf.keras.layers.Conv2D(64, kernel_size=(5, 5), activation='relu'),
    16. tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    17. tf.keras.layers.Flatten(),
    18. tf.keras.layers.Dense(120, activation='relu'),
    19. tf.keras.layers.Dense(84, activation='relu'),
    20. tf.keras.layers.Dense(10, activation='softmax')
    21. ])
    22. model.compile(optimizer='adam',
    23. loss='sparse_categorical_crossentropy',
    24. metrics=['accuracy'])
    25. # 重塑数据以适应模型
    26. x_train = x_train.reshape(-1, 28, 28, 1)
    27. x_test = x_test.reshape(-1, 28, 28, 1)
    28. # 训练模型
    29. model.fit(x_train, y_train, epochs=5)
    30. # 评估模型
    31. test_loss, test_acc = model.evaluate(x_test, y_test)
    32. print(f'测试准确率: {test_acc}')
    33. # 保存模型
    34. model.save('lenet-5_model.h5')
    35. print('模型已保存至 lenet-5_model.h5')
    36. # 加载模型
    37. loaded_model = tf.keras.models.load_model('lenet-5_model.h5')
    38. print('模型已加载')
    39. # 加载并预处理本地图片
    40. def load_and_preprocess_image(image_path):
    41. img = image.load_img(image_path, color_mode="grayscale", target_size=(28, 28))
    42. img_array = image.img_to_array(img)
    43. img_array = img_array / 255.0 # 归一化
    44. img_array = np.expand_dims(img_array, axis=0) # 添加批次维度
    45. return img_array
    46. # 预测本地图片
    47. image_path = '4.png' # 替换为你的本地图片路径
    48. img_array = load_and_preprocess_image(image_path)
    49. # 使用加载的模型进行预测
    50. predictions = loaded_model.predict(img_array)
    51. predicted_label = np.argmax(predictions)
    52. # 打印预测结果
    53. print(f'预测结果: {predicted_label}')
    54. # 显示图片
    55. plt.imshow(img_array[0, :, :, 0], cmap='gray')
    56. plt.title(f'预测结果: {predicted_label}')
    57. plt.show()

            解释:image_path为本地图片路径,通过model.save()方法实现模型的保存功能,下次预测使用的时候直接使用训练好的模型即可。下面将给出可直接预测的代码:

    1. import tensorflow as tf
    2. from tensorflow.keras.preprocessing import image
    3. import numpy as np
    4. import matplotlib.pyplot as plt
    5. from matplotlib.font_manager import FontProperties
    6. # 加载模型
    7. loaded_model = tf.keras.models.load_model('lenet-5_model.h5')
    8. print('模型已加载')
    9. # 加载并预处理本地图片
    10. def load_and_preprocess_image(image_path):
    11. img = image.load_img(image_path, color_mode="grayscale", target_size=(28, 28))
    12. img_array = image.img_to_array(img)
    13. img_array = img_array / 255.0 # 归一化
    14. img_array = np.expand_dims(img_array, axis=0) # 添加批次维度
    15. return img_array
    16. # 预测本地图片
    17. image_path = '7.png' # 替换为你的本地图片路径
    18. img_array = load_and_preprocess_image(image_path)
    19. # 使用加载的模型进行预测
    20. predictions = loaded_model.predict(img_array)
    21. predicted_label = np.argmax(predictions)
    22. # 打印预测结果
    23. print(f'预测结果: {predicted_label}')
    24. # 设置支持中文的字体
    25. font_path = "C:/Windows/Fonts/simhei.ttf" # 替换为你的字体路径,例如 SimHei.ttf
    26. font_prop = FontProperties(fname=font_path)
    27. # 显示图片
    28. plt.imshow(img_array[0, :, :, 0], cmap='gray')
    29. plt.title(f'预测结果: {predicted_label}', fontproperties=font_prop)
    30. plt.show()

    3、总结

            使用tensorflow完成手写体图片的识别功能,其主要难点在安装依赖环境,其他的都是比较简单的事情。

    学习之所以会想睡觉,是因为那是梦开始的地方。
    ଘ(੭ˊᵕˋ)੭ (开心) ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)
                                                                                                            ------不写代码不会凸的小刘

  • 相关阅读:
    java 3至5年常见面试题及答案
    文件预览服务器kkfileview安装部署(linux 版)
    详解RFC 793文档-2
    SpringBoot的基本使用
    VUE指令、computed计算属性和watch 侦听器(附带详细案例)
    将字体颜色设置为渐变色 --字体倾斜--数组转字符串--旋转(一些样式的设置)
    跟着GPT学设计模式之桥接模式
    tomcat目录下创建临时文件,长时间没有使用会被系统清理掉
    【金九银十必问面试题】这应该是面试官最想听到的回答,Mysql如何解决幻读问题?
    Docker初认识
  • 原文地址:https://blog.csdn.net/qq_40834643/article/details/140102171