码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • 深度学习训练营之彩色图片分类


    深度学习训练营

    • 原文链接
    • 环境介绍
    • 前置工作
      • 设置GPU
    • 导入数据
    • 归一化操作
    • 图片可视化
    • 构建CNN网络
    • 进行编译
    • 模型训练
    • 结果可视化
      • 图片展示
      • 对图片的内容进行辨别
    • 模型的精度评估

    原文链接

    • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
    • 🍦 参考文章:365天深度学习训练营-第P1周:实现mnist手写数字识别
    • 🍖 原作者:K同学啊|接辅导、项目定制

    环境介绍

    • 语言环境:Python3.9.13
    • 编译器:jupyter notebook
    • 深度学习环境:TensorFlow2

    前置工作

    设置GPU

    因为本次实验的数据量过大,所有设置多个GPU很有必要

    # K同学啊深度学习练习
    import tensorflow as tf
    gpus = tf.config.list_physical_devices("GPU")
    
    if gpus:
        gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
        tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
        tf.config.set_visible_devices([gpu0],"GPU")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    导入数据

    import tensorflow as tf
    from tensorflow.keras import datasets, layers, models
    import matplotlib.pyplot as plt
    
    (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
    
    • 1
    • 2
    • 3
    • 4
    • 5

    数据量比较大,下载的时间会比较长
    在这里插入图片描述

    归一化操作

    # 将像素的值标准化至0到1的区间内。
    train_images, test_images = train_images / 255.0, test_images / 255.0
    
    train_images.shape,test_images.shape,train_labels.shape,test_labels.shape
    
    • 1
    • 2
    • 3
    • 4

    在这里插入图片描述

    图片可视化

    对于图片的分类进行命名

    [‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’,‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’]

    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']
    
    plt.figure(figsize=(20,10))
    for i in range(20):
        plt.subplot(5,10,i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(train_images[i], cmap=plt.cm.binary)
        plt.xlabel(class_names[train_labels[i][0]])
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    在这里插入图片描述

    构建CNN网络

    #设置CNN网络
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)), #卷积层1,卷积核3*3
        layers.MaxPooling2D((2, 2)),                   #池化层1,2*2采样
        layers.Conv2D(64, (3, 3), activation='relu'),  #卷积层2,卷积核3*3
        layers.MaxPooling2D((2, 2)),                   #池化层2,2*2采样
        layers.Conv2D(64, (3, 3), activation='relu'),  #卷积层3,卷积核3*3
        
        layers.Flatten(),                      #Flatten层,连接卷积层与全连接层
        layers.Dense(64, activation='relu'),   #全连接层,特征进一步提取
        layers.Dense(10)                       #输出层,输出预期结果
    ])
    
    model.summary()  # 打印网络结构
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    在这里插入图片描述

    进行编译

    进行编译操作

    #进行编译
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    
    • 1
    • 2
    • 3
    • 4

    模型训练

    epoch设置为10,实现提高精度

    #对模型进行训练
    history = model.fit(train_images, train_labels, epochs=10, 
                        validation_data=(test_images, test_labels))
    
    • 1
    • 2
    • 3

    在这里插入图片描述

    结果可视化

    图片展示

    #图片显示
    plt.imshow(test_images[3])
    
    • 1
    • 2

    在这里插入图片描述

    对图片的内容进行辨别

    import numpy as np
    
    pre = model.predict(test_images)
    print(class_names[np.argmax(pre[3])])
    
    • 1
    • 2
    • 3
    • 4

    在这里插入图片描述

    模型的精度评估

    #模型评估
    import matplotlib.pyplot as plt
    
    plt.plot(history.history['accuracy'], label='accuracy')
    plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.ylim([0.5, 1])
    plt.legend(loc='lower right')
    plt.show()
    
    test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    在这里插入图片描述
    计算结果

    print(test_acc)#打印结果
    
    • 1

    在这里插入图片描述

  • 相关阅读:
    网页前端知识汇总(三)——网页前端利用二维码插件qrcode生成在线二维码
    线性筛求欧拉函数前n个和
    Kubernetes:kubelet 源码分析之探针
    为什么说实验室信息管理系统LIMS是势在必行,有哪些必要性
    Linux文件权限
    Java web应用性能分析之【压测工具ab】
    【Python零基础入门篇 · 4】:字符串的运算符、下标和切片
    Qt 5.9.8 安装教程
    【深度思考】:人工智能的发展会带来生产力和生产关系的变革吗?
    HIT_OS_LAB2 调试分析 Linux 0.00 多任务切换
  • 原文地址:https://blog.csdn.net/qq_62904883/article/details/128207433
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号