• 深度学习100例 —— 卷积神经网络(CNN)识别眼睛状态


    活动地址:CSDN21天学习挑战赛

    深度学习100例——卷积神经网络(CNN)识别眼睛状态

    我的环境

    在这里插入图片描述

    1. 前期准备工作

    1.1 设置GPU
    import tensorflow as tf
    
    gpus = tf.config.list_physical_devices("GPU")
    
    if gpus:
        tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
        tf.config.set_visible_devices([gpus[0]],"GPU")
    
    # 打印显卡信息,确认GPU可用
    print(gpus)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    在这里插入图片描述

    1.2 导入数据
    import matplotlib.pyplot as plt
    # 支持中文
    plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
    plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
    
    import os,PIL
    
    # 设置随机种子尽可能使结果可以重现
    import numpy as np
    np.random.seed(1)
    
    # 设置随机种子尽可能使结果可以重现
    import tensorflow as tf
    tf.random.set_seed(1)
    
    import pathlib
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    data_dir = "第17天/017_Eye_dataset"
    data_dir = pathlib.Path(data_dir)
    
    • 1
    • 2
    1.3 查看数据
    image_count = len(list(data_dir.glob('*/*')))
    
    print("图片总数为:",image_count)
    
    • 1
    • 2
    • 3

    在这里插入图片描述

    2. 数据预处理

    2.1 加载数据
    batch_size = 64
    img_height = 224
    img_width = 224
    
    
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="training",
        seed=12,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    在这里插入图片描述

    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="validation",
        seed=12,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    在这里插入图片描述

    通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。

    class_names = train_ds.class_names
    print(class_names)
    
    • 1
    • 2

    在这里插入图片描述

    2.2 可视化数据
    plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
    plt.suptitle("datashow")
    
    for images, labels in train_ds.take(1):
        for i in range(8):
            
            ax = plt.subplot(2, 4, i + 1)  
            
            ax.patch.set_facecolor('yellow')
            
            plt.imshow(images[i].numpy().astype("uint8"))
            plt.title(class_names[labels[i]])
            
            plt.axis("off")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    在这里插入图片描述

    2.3 再次检查数据
    for image_batch, labels_batch in train_ds:
        print(image_batch.shape)
        print(labels_batch.shape)
        break
    
    • 1
    • 2
    • 3
    • 4
    • Image_batch是形状的张量(8,224,224,3)。这是一批形状240x240x3的8张图片(最后一维指的是彩色通道RGB)。
    • Label_batch是形状(8,)的张量,这些标签对应8张图片
    2.4 配置数据集

    shuffle():打乱数据。

    prefetch():预取数据,加速运行。

    cache():将数据集缓存到内存当中,加速运行。

    AUTOTUNE = tf.data.AUTOTUNE
    
    train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
    val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
    
    • 1
    • 2
    • 3
    • 4

    3. 调用官方模型

    model = tf.keras.applications.VGG16()
    # 打印模型信息
    model.summary()
    
    • 1
    • 2
    • 3

    在这里插入图片描述

    4. 设置动态学习率

    • 学习率大
      • 优点:1、加快学习速率。2、有助于跳出局部最优值。
      • 缺点:1、导致模型训练不收敛。2、单单使用大学习率容易导致模型不精确。
    • 学习率小
      • 优点:1、有助于模型收敛、模型细化。2、提高模型精度。
      • 缺点:1、很难跳出局部最优值。2、收敛缓慢。

    注意:这里设置的动态学习率为:指数衰减型(ExponentiaIDecay)。假设1个epoch有100个batch(相当于100step),20个epoch过后,step==2000,即step会随着epoch累加。计算公式如下:

    learning _rate = initial_learning_rate *decay_rate ^(step / decay_steps)

    # 设置初始学习率
    initial_learning_rate = 1e-4
    
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate, 
            decay_steps=20,      # 敲黑板!!!这里是指 steps,不是指epochs
            decay_rate=0.96,     # lr经过一次衰减就会变成 decay_rate*lr
            staircase=True)
    
    # 将指数衰减学习率送入优化器
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    5. 编译

    • 损失函数(loss):用于衡量模型在训练期间的准确率。
    • 优化器(optimizer) :决定模型如何根据其看到的数据和自身的损失函数进行更新。
    • 评价函数(metrics) :用于监控训练和测试步骤。以下示例使用了准确率。即被正确分类的图像的比率。
    model.compile(optimizer=optimizer,loss ='sparse_categorical_crossentropy',metrics = ['accuracy'])
    
    • 1

    6. 训练模型

    epochs = 10
    
    history = model.fit(train_ds,validation_data=val_ds,epochs=epochs)
    
    • 1
    • 2
    • 3

    在这里插入图片描述

    7. 模型评估

    7.1 Accuracy图与Loss图
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    
    epochs_range = range(epochs)
    
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    
    plt.plot(epochs_range, acc, label='Training Accuracy')
    plt.plot(epochs_range, val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.title('Training and Validation Accuracy')
    
    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, label='Training Loss')
    plt.plot(epochs_range, val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.title('Training and Validation Loss')
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    在这里插入图片描述

    7.2 混淆矩阵

    Seaborn是一个画图库,它基于Matplotlib核心库进行了更高阶的API封装,可以让你轻松地画出更漂亮的图形。Seaborn 的漂亮主要体现在配色更加舒服、以及图形元素的样式更加细腻。

    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    import pandas as pd
    
    # 定义一个绘制混淆矩阵图的函数
    def plot_cm(labels, predictions):
        
        # 生成混淆矩阵
        conf_numpy = confusion_matrix(labels, predictions)
        # 将矩阵转化为 DataFrame
        conf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  
        
        plt.figure(figsize=(8,7))
        
        sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")
        
        plt.title('hunxiaojuzhen',fontsize=15)
        plt.ylabel('real',fontsize=14)
        plt.xlabel('predict',fontsize=14)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    val_pre   = []
    val_label = []
    
    for images, labels in val_ds:#这里可以取部分验证数据(.take(1))生成混淆矩阵
        for image, label in zip(images, labels):
            # 需要给图片增加一个维度
            img_array = tf.expand_dims(image, 0) 
            # 使用模型预测图片中的人物
            prediction = model.predict(img_array)
    
            val_pre.append(class_names[np.argmax(prediction)])
            val_label.append(class_names[label])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    plot_cm(val_label, val_pre)
    
    • 1

    在这里插入图片描述

    8. 保存和加载模型

    # 保存模型
    model.save('model/17_model.h5')
    # 加载模型
    new_model = tf.keras.models.load_model('model/17_model.h5')
    
    • 1
    • 2
    • 3
    • 4

    9. 预测

    # 采用加载的模型(new_model)来看预测结果
    
    plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
    plt.suptitle("result_predict")
    
    for images, labels in val_ds.take(1):
        for i in range(8):
            ax = plt.subplot(2, 4, i + 1)  
            
            # 显示图片
            plt.imshow(images[i].numpy().astype("uint8"))
            
            # 需要给图片增加一个维度
            img_array = tf.expand_dims(images[i], 0) 
            
            # 使用模型预测图片中的人物
            predictions = new_model.predict(img_array)
            plt.title(class_names[np.argmax(predictions)])
    
            plt.axis("off")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    在这里插入图片描述

  • 相关阅读:
    小试单一职责原则
    【Java】Arrays类、static关键字
    [项目管理-17]:技术研发、项目管理、部门管理三种角色的详细比较
    LeetCode 每日一题——623. 在二叉树中增加一行
    MySQL怎么为表添加描述
    Cannot download sources:IDEA源码无法下载
    uniapp开发手机APP、H5网页、微信小程序、长列表插件
    最全!2024百度Spring Zuul面试题大全,详解每个角落,面试必备宝典!收藏版!
    309 买卖股票的最佳时机含冷冻期(状态机DP)(灵神笔记)
    android NDK 开发包,网盘下载,不限速
  • 原文地址:https://blog.csdn.net/weixin_44226181/article/details/126277907