• DAY17-深度学习100例-卷积神经网络(CNN)识别眼睛状态



    前言

    本文通过对人眼状态的识别达到检测注意力的目的。

    一、前期工作

    1.设置GPU

    import tensorflow as tf
    
    gpus = tf.config.list_physical_devices("GPU")
    
    if gpus:
        tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
        tf.config.set_visible_devices([gpus[0]],"GPU")
    
    # 打印显卡信息,确认GPU可用
    print(gpus)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    2.导入数据

    import matplotlib.pyplot as plt
    # 支持中文
    plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
    plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
    
    import os,PIL
    
    # 设置随机种子尽可能使结果可以重现
    import numpy as np
    np.random.seed(1)
    
    # 设置随机种子尽可能使结果可以重现
    import tensorflow as tf
    tf.random.set_seed(1)
    
    import pathlib
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    data_dir = "017_Eye_dataset"
    
    data_dir = pathlib.Path(data_dir)
    
    • 1
    • 2
    • 3

    3.查看数据

    image_count = len(list(data_dir.glob('*/*')))
    
    print("图片总数为:",image_count)
    
    • 1
    • 2
    • 3

    在这里插入图片描述

    二、数据预处理

    1.加载数据

    batch_size = 64
    img_height = 224
    img_width = 224
    
    • 1
    • 2
    • 3
    """
    关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
    """
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="training",
        seed=12,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    在这里插入图片描述

    """
    关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
    """
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="validation",
        seed=12,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    在这里插入图片描述

    我们可以通过 class_names 输出数据集的标签。

    class_names = train_ds.class_names
    print(class_names)
    
    • 1
    • 2

    在这里插入图片描述

    2.可视化数据

    plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
    plt.suptitle("数据展示")
    
    for images, labels in train_ds.take(1):
        for i in range(8):
            
            ax = plt.subplot(2, 4, i + 1)  
            
            ax.patch.set_facecolor('yellow')
            
            plt.imshow(images[i].numpy().astype("uint8"))
            plt.title(class_names[labels[i]])
            
            plt.axis("off")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    在这里插入图片描述

    3.再次检查数据

    for image_batch, labels_batch in train_ds:
        print(image_batch.shape)
        print(labels_batch.shape)
        break
    
    • 1
    • 2
    • 3
    • 4
    (64, 224, 224, 3)
    (64,)
    
    • 1
    • 2

    4.配置数据集

    AUTOTUNE = tf.data.AUTOTUNE
    
    train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
    val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
    
    • 1
    • 2
    • 3
    • 4

    三、调用官方网络模型

    model = tf.keras.applications.VGG16()
    # 打印模型信息
    model.summary()
    
    • 1
    • 2
    • 3

    在这里插入图片描述

    四、设置动态学习率

    首先说明学习率大与学习率小的优缺点。

    学习率大:
    优点:1、加快学习速率;2、有利于跳出局部最优值。
    缺点:1、导致模型训练不收敛;2、仅使用大学习率容易导致模型不精确。

    学习率小:
    优点:1、有助于模型收敛,模型细化;2、提高模型精度。
    缺点:1、很难跳出局部最优值;2、收敛缓慢。

    这里设置的学习率为:指数衰减性。假设1个epoch有100个batch(相当于100 steps),20个epoch后,step==2000,即step会随epoch累加。计算公式如下:
    在这里插入图片描述

    # 设置初始学习率
    initial_learning_rate = 1e-4
    
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate, 
            decay_steps=20,      # 敲黑板!!!这里是指 steps,不是指epochs
            decay_rate=0.96,     # lr经过一次衰减就会变成 decay_rate*lr
            staircase=True)
    
    # 将指数衰减学习率送入优化器
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    五、编译

    model.compile(optimizer=optimizer,
                  loss     ='sparse_categorical_crossentropy',
                  metrics  =['accuracy'])
    
    • 1
    • 2
    • 3

    六、训练模型

    epochs = 10
    
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    注:在cpu上进行训练,由于内存不够,训练过程中断。

    七、评估模型

    1.Accuracy和Loss图

    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    
    epochs_range = range(epochs)
    
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    
    plt.plot(epochs_range, acc, label='Training Accuracy')
    plt.plot(epochs_range, val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.title('Training and Validation Accuracy')
    
    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, label='Training Loss')
    plt.plot(epochs_range, val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.title('Training and Validation Loss')
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    2.混淆矩阵

    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    import pandas as pd
    
    # 定义一个绘制混淆矩阵图的函数
    def plot_cm(labels, predictions):
        
        # 生成混淆矩阵
        conf_numpy = confusion_matrix(labels, predictions)
        # 将矩阵转化为 DataFrame
        conf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  
        
        plt.figure(figsize=(8,7))
        
        sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")
        
        plt.title('混淆矩阵',fontsize=15)
        plt.ylabel('真实值',fontsize=14)
        plt.xlabel('预测值',fontsize=14)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    val_pre   = []
    val_label = []
    
    for images, labels in val_ds:#这里可以取部分验证数据(.take(1))生成混淆矩阵
        for image, label in zip(images, labels):
            # 需要给图片增加一个维度
            img_array = tf.expand_dims(image, 0) 
            # 使用模型预测图片中的人物
            prediction = model.predict(img_array)
    
            val_pre.append(class_names[np.argmax(prediction)])
            val_label.append(class_names[label])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    plot_cm(val_label, val_pre)
    
    • 1

    八、保存和加载模型

    # 保存模型
    model.save('model/17_model.h5')
    
    • 1
    • 2
    # 加载模型
    new_model = tf.keras.models.load_model('model/17_model.h5')
    
    • 1
    • 2

    九、预测

    # 采用加载的模型(new_model)来看预测结果
    
    plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
    plt.suptitle("预测结果展示")
    
    for images, labels in val_ds.take(1):
        for i in range(8):
            ax = plt.subplot(2, 4, i + 1)  
            
            # 显示图片
            plt.imshow(images[i].numpy().astype("uint8"))
            
            # 需要给图片增加一个维度
            img_array = tf.expand_dims(images[i], 0) 
            
            # 使用模型预测图片中的人物
            predictions = new_model.predict(img_array)
            plt.title(class_names[np.argmax(predictions)])
    
            plt.axis("off")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
  • 相关阅读:
    nginx 集群部署
    国内离线安装 Chrome 扩展程序的方法总结
    513找树左下角值
    JSP SSM 成果展示系统myeclipse开发mysql数据库springMVC模式java编程计算机网页设计
    Shiro-12-caching 缓存
    Mysql 45讲学习笔记(七)行锁
    腾讯云弹性公网IP是什么?EIP是什么意思?
    游戏专属i9-13900k服务器配置一个月多少钱
    QT连接数据库
    测开日常积累 —— 自动化测试里的数据驱动和关键字驱动思路的理解
  • 原文地址:https://blog.csdn.net/weixin_44336912/article/details/126334479