• tensorflow的简单使用


    参考文章:

    https://www.tensorflow.org/tutorials/keras/classification

    https://www.datacamp.com/tutorial/tensorflow-tutorial

    一、引入模块

            

    1. import tensorflow as tf
    2. import tensorflow.keras as keras
    3. import numpy as np
    4. import matplotlib.pyplot as plt
    5. import skimage
    6. from skimage import transform
    7. from skimage.color import rgb2gray

    二、加载数据集

            1、使用minist数据集

                    

    1. fashion_mnist = keras.datasets.fashion_mnist
    2. (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

            2、自定义数据集

            

    1. def load_data(data_directory):
    2. directories = [d for d in os.listdir(data_directory)
    3. if os.path.isdir(os.path.join(data_directory, d))]
    4. labels = []
    5. images = []
    6. all_image_paths = []
    7. for d in directories:
    8. label_directory = os.path.join(data_directory, d)
    9. file_names = [os.path.join(label_directory, f)
    10. for f in os.listdir(label_directory)
    11. # if f.endswith(".ppm")
    12. ]
    13. for f in file_names:
    14. images.append(skimage.io.imread(f))
    15. labels.append(int(d))
    16. all_image_paths.append(f)
    17. return images, labels
    18. ROOT_PATH = "E:/dataset/"
    19. train_data_directory = os.path.join(ROOT_PATH, "test/training")
    20. test_data_directory = os.path.join(ROOT_PATH, "test/testing")
    21. train_images, train_labels = load_data(train_data_directory)
    22. test_images, test_labels = load_data(test_data_directory)
    23. # 训练数据
    24. train_images = [transform.resize(image, (28, 28)) for image in train_images]
    25. train_images = np.array(train_images)
    26. train_images = rgb2gray(train_images)
    27. train_labels = np.array(train_labels)
    28. # 测试数据
    29. test_images = [transform.resize(image, (28, 28)) for image in test_images]
    30. test_images = np.array(test_images)
    31. test_images = rgb2gray(test_images)
    32. test_labels = np.array(test_labels)

    三、构建模型

            

    1. # 创建模型
    2. model = keras.Sequential([
    3. keras.layers.Flatten(input_shape=(28, 28)),
    4. keras.layers.Dense(128, activation='relu'), # 激活函数 relu
    5. keras.layers.Dense(128)
    6. ])
    7. # 优化器
    8. model.compile(optimizer='adam',
    9. loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    10. metrics=['accuracy'])
    11. # 训练模型
    12. model.fit(train_images, train_labels, epochs=100) # epochs迭代次数
    13. # 评估准确率
    14. test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
    15. print('\nTest test_loss:{0} accuracy:{1}'.format(test_loss, test_acc))

    四、进行预测

            

    1. probability_model = tf.keras.Sequential([model, tf.keras.layers.Softmax()])
    2. predictions = probability_model.predict(test_images)
    3. pre_value = np.argmax(predictions[2]) # 预测标签
    4. truth_value = test_labels[2] # 实际标签
    5. print('\nTest pre_value:{0} truth_value:{1}'.format(pre_value, truth_value))
    6. # 可视化 流程图
    7. tf.keras.utils.plot_model(model, 'model1.png', show_shapes=True, show_dtype=True,show_layer_names=True)
    8. # 查看图片
    9. fig = plt.figure(figsize=(10, 10))
    10. for i in range(10):
    11. truth = test_labels[i]
    12. prediction = np.argmax(predictions[i])
    13. plt.subplot(3, 4, 1 + i) # 五行 两列
    14. plt.axis('off') # 坐标轴 off关闭坐标轴
    15. color = 'green' if truth == prediction else 'red'
    16. plt.text(40, 10, "Truth: {0}\nPrediction: {1}".format(truth, prediction),
    17. fontsize=12, color=color)
    18. plt.imshow(test_images[i], cmap="gray") # gray显示灰度图 默认显示热图
    19. # 显示图片
    20. plt.show()

            

  • 相关阅读:
    【操作系统】数据校验码——奇偶校验和海明校验
    深度解读DBSCAN聚类算法:技术与实战全解析
    爬虫抓取网站数据
    解决eclipse中的Java文件,使用idea打开的乱码问题
    《Python+Kivy(App开发)从入门到实践》自学笔记:高级UX部件——DropDown下拉列表
    用于NLP领域的排序模型最佳实践
    GIS前端-地图操作与交互
    境电商为什么要做独立站?API一键对接秒上架瞬间拥有全平台几十亿商品和用户!
    【Azure 云服务】Azure Cloud Service (Extended Support) 云服务开启诊断日志插件 WAD Extension (Windows Azure Diagnostic) 无法正常工作的原因
    RabbitMQ详解
  • 原文地址:https://blog.csdn.net/aawuwuwuxx/article/details/126907664