• Animal recognition with a ResNet50V2 network in TensorFlow


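    Preface

    Many readers have asked me how to recognize XXX and how to build the corresponding neural network. When it comes to building networks, I mostly follow the textbook: I can understand them, but I have not studied this area in depth. Fortunately, tools such as TensorFlow and Paddle now ship mature networks that can be used directly.
    This article reworks some of my earlier posts to use one of these built-in networks (ResNet50V2) and implements simple recognition of several animal species.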

    Environment

    tensorflow:2.9.1
    keras:2.9.0
    os:windows10
    gpu:RTX3070
    cuda:cuda_11.4.r11.4
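    I won't repeat how to install TensorFlow here. The important point is that different TensorFlow and Keras versions expose some utilities in different places, so the imports below (for example load_img and img_to_array from tensorflow.keras.utils) assume the versions listed above.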
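Because import paths shift between releases, it can help to confirm the installed versions before running the scripts. A minimal stdlib-only sketch; the distribution names `tensorflow` and `keras` are assumptions (a GPU or vendor build may be installed under another name), and the helper `installed_versions` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("tensorflow", "keras")):
    """Return {package: version string, or None if not installed}."""
    result = {}
    for pkg in packages:
        try:
            result[pkg] = version(pkg)
        except PackageNotFoundError:
            # Not installed under this distribution name
            result[pkg] = None
    return result


if __name__ == "__main__":
    for pkg, ver in installed_versions().items():
        print(f"{pkg}: {ver or 'not installed'}")
```

With the environment listed above, this should report 2.9.1 for tensorflow and 2.9.0 for keras.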

    Data preparation

    Link: https://pan.baidu.com/s/1J7yRsTS2o0LcVkbKKJD-Bw  Extraction code: 6666
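    After extraction, the dataset lives under animal/Animals_Classification/Animal-Data-V2/Data-V2/, which contains the "Training Data", "Validation Data", "Testing Data" and "Interesting Data" folders; each of them has one subfolder per animal class.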
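Since both the class list and the class distribution are derived from this folder layout, a quick per-class file count of one split can catch extraction problems early. A stdlib-only sketch; the function name `count_images_per_class` is hypothetical, and it assumes every file inside a class folder is an image.

```python
import os


def count_images_per_class(split_dir):
    """Map each class subfolder of split_dir to the number of files it contains."""
    counts = {}
    for class_name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, class_name)
        if os.path.isdir(class_dir):  # skip stray files at the split level
            counts[class_name] = sum(
                1 for f in os.listdir(class_dir)
                if os.path.isfile(os.path.join(class_dir, f))
            )
    return counts


if __name__ == "__main__":
    train_dir = "./animal/Animals_Classification/Animal-Data-V2/Data-V2/Training Data/"
    if os.path.isdir(train_dir):
        print(count_images_per_class(train_dir))
```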

    Code

    1. Model training code (animalv2_model_train.py)

    Imports
    import os
    
    import plotly.express as px
    import matplotlib.pyplot as plt
    from IPython.display import clear_output as cls
    import numpy as np
    from glob import glob
    import pandas as pd
    
    # Model
    import keras
    from keras.models import Sequential, load_model
    from keras.layers import GlobalAvgPool2D as GAP, Dense, Dropout
    from keras.preprocessing.image import ImageDataGenerator
    
    # Callbacks
    from keras.callbacks import EarlyStopping, ModelCheckpoint
    
    # Pretrained model and image-loading utilities
    import tensorflow as tf
    from tensorflow.keras.applications import ResNet50V2
    from tensorflow.keras.utils import load_img, img_to_array
    
    Dataset processing
    root_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Training Data/'
    valid_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Validation Data/'
    test_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Testing Data/'
    # Animal classes: one subfolder per class, in alphanumeric order
    class_names = sorted(os.listdir(root_path))
    n_classes = len(class_names)
    
    print(f"Total Number of Classes : {n_classes} \nClass Names : {class_names}")
    
    
    class_dis = [len(os.listdir(root_path+name)) for name in class_names]
    
    fig = px.pie(names=class_names, values=class_dis, title="Training Class Distribution", hole=0.4)
    fig.update_layout({'title':{'x':0.48}})
    fig.show()
    
    fig = px.bar(x=class_names, y=class_dis, title="Training Class Distribution", color=class_names)
    fig.update_layout({'title':{'x':0.48}})
    fig.show()
    
    # Normalization: scale pixels to [0, 1]; the training set also gets light augmentation
    train_gen = ImageDataGenerator(rescale=1/255., rotation_range=10, horizontal_flip=True)
    valid_gen = ImageDataGenerator(rescale=1/255.)
    test_gen = ImageDataGenerator(rescale=1/255.)
    
    # Load Data
    # class_mode='sparse' yields integer class indices (0-9), matching the sparse_categorical_crossentropy loss used below
    train_ds = train_gen.flow_from_directory(root_path, class_mode='sparse', target_size=(256,256), shuffle=True, batch_size=32)
    valid_ds = valid_gen.flow_from_directory(valid_path, class_mode='sparse', target_size=(256,256), shuffle=True, batch_size=32)
    test_ds = test_gen.flow_from_directory(test_path, class_mode='sparse', target_size=(256,256), shuffle=True, batch_size=32)
    

    Output:

    Total Number of Classes : 10 
    Class Names : ['Cat', 'Cow', 'Dog', 'Elephant', 'Gorilla', 'Hippo', 'Monkey', 'Panda', 'Tiger', 'Zebra']
    Found 20000 images belonging to 10 classes.
    Found 1000 images belonging to 10 classes.
    Found 1907 images belonging to 10 classes.
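flow_from_directory infers the classes from the subfolder names and assigns label indices in alphanumeric order, which is why `sorted(os.listdir(root_path))` yields the matching `class_names`. With `batch_size=32`, each epoch runs ceil(images / 32) batches; the plain-Python check below uses only the counts printed above and reproduces the `625/625` that appears in the training log.

```python
import math

batch_size = 32
# Image counts printed by flow_from_directory above
split_sizes = {"train": 20000, "valid": 1000, "test": 1907}

# Batches per epoch = ceil(images / batch_size); the last batch may be smaller than 32
steps = {split: math.ceil(n / batch_size) for split, n in split_sizes.items()}
print(steps)  # {'train': 625, 'valid': 32, 'test': 60}

# Label index i corresponds to class_names[i] (alphanumeric order of the folder names)
class_names = ['Cat', 'Cow', 'Dog', 'Elephant', 'Gorilla',
               'Hippo', 'Monkey', 'Panda', 'Tiger', 'Zebra']
label_of = {name: i for i, name in enumerate(class_names)}
print(label_of["Gorilla"], label_of["Zebra"])  # 4 9
```

These indices line up with the evaluation output later on: the Gorilla prediction peaks at position 4 of the probability vector and the Zebra prediction at position 9.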
    
    Image display
    def show_images(GRID=[5, 5], model=None, size=(20, 20), data=train_ds):
        n_rows = GRID[0]
        n_cols = GRID[1]
        n_images = n_cols * n_rows
    
        i = 1
        plt.figure(figsize=size)
        for images, labels in data:
            # Take one random image from each batch; labels hold class indices
            idx = np.random.randint(len(images))
            image, label = images[idx], class_names[int(labels[idx])]
    
            plt.subplot(n_rows, n_cols, i)
            plt.imshow(image)
    
            if model is None:
                title = f"Class : {label}"
            else:
                pred = class_names[int(np.argmax(model.predict(image[np.newaxis, ...])))]
                title = f"Org : {label}, Pred : {pred}"
                cls()
    
            plt.title(title)
            plt.axis('off')
    
            i += 1
            if i >= (n_images + 1):
                break
    
        plt.tight_layout()
        plt.show()
    
    def load_image(path):
        image = tf.cast(tf.image.resize(img_to_array(load_img(path))/255., (256,256)), tf.float32)
        return image
    def show_image(image, title=None):
        plt.imshow(image)
        plt.axis('off')
        plt.title(title)
    
    show_images(data=train_ds)
    show_images(data=valid_ds)
    show_images(data=test_ds)
    
    path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Interesting Data/'
    interesting_images = [glob(path + name + "/*") for name in class_names]
    
    # Show the "Interesting Data" images of every class, one row per class
    for class_name in class_names:
        plt.figure(figsize=(25, 8))
        class_images = interesting_images[class_names.index(class_name)]
        for i, i_path in enumerate(class_images):
            # os.path.basename also handles the "\" separator that glob returns on Windows
            title = os.path.splitext(os.path.basename(i_path))[0]
            image = load_image(i_path)
            plt.subplot(1, len(class_images), i+1)
            show_image(image, title=title.title())
        plt.show()
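Why basename rather than `split("/")`: on Windows, glob joins the directory part of the pattern and each match with os.path.join, which inserts `\`, so a result mixes both separators. A small illustration with the standard `ntpath` and `posixpath` modules (the Windows and POSIX implementations behind `os.path`); the file name is hypothetical.

```python
import ntpath
import posixpath

# Hypothetical path in the shape glob returns on Windows (note the final "\")
win_path = "./animal/Interesting Data/Cat\\sleepy cat.jpeg"

# Splitting on "/" alone keeps the "Cat\" prefix in the title
naive = win_path.split("/")[-1].split(".")[0]

# ntpath (os.path on Windows) treats both "/" and "\" as separators
robust = ntpath.splitext(ntpath.basename(win_path))[0]

# posixpath (os.path on Linux/macOS) splits on "/", the only separator glob uses there
posix = posixpath.splitext(posixpath.basename("./animal/Interesting Data/Cat/sleepy cat.jpeg"))[0]

print(repr(naive), repr(robust), repr(posix))
```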
    

    Model training
    with tf.device("/GPU:0"):
        # Define the network: ResNet50V2 pretrained on ImageNet, without its top classifier, frozen
        base_model = ResNet50V2(input_shape=(256,256,3), include_top=False)
        base_model.trainable = False
        cls()
    
        # Classification head on top of the frozen backbone
        name = "ResNet50V2"
        model = Sequential([
            base_model,
            GAP(),
            Dense(256, activation='relu', kernel_initializer='he_normal'),
            Dropout(0.2),
            Dense(n_classes, activation='softmax')
        ], name=name)
    
        # Callbacks
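        # EarlyStopping: stop once val_loss has not improved for 3 consecutive epochs, then restore the best weights
        # ModelCheckpoint: save the model whenever val_loss reaches a new best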
        cbs = [EarlyStopping(patience=3, restore_best_weights=True), ModelCheckpoint(name + "_V2.h5", save_best_only=True)]
    
        # Model
        opt = tf.keras.optimizers.Adam(learning_rate=2e-3)
        model.compile(loss='sparse_categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
    
        # Model Training
        history = model.fit(train_ds, validation_data=valid_ds, callbacks=cbs, epochs=50)
    
    data = pd.DataFrame(history.history)
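The loss `sparse_categorical_crossentropy` consumes the integer class index from the generator directly rather than a one-hot vector. Per sample, it is the negative log of the softmax probability at the true index, which equals categorical cross-entropy with the corresponding one-hot label. A plain-Python sketch with made-up probabilities (not model output); Keras additionally clips probabilities for numerical safety, which is omitted here.

```python
import math


def sparse_ce(probs, label):
    """Cross-entropy with an integer label: -log p[label]."""
    return -math.log(probs[label])


def categorical_ce(probs, one_hot):
    """Cross-entropy with a one-hot label: -sum(y_i * log p_i), skipping zero terms."""
    return -sum(y * math.log(p) for y, p in zip(one_hot, probs) if y)


# Illustrative softmax output over 3 classes (values are made up)
probs = [0.1, 0.7, 0.2]
label = 1
one_hot = [0, 1, 0]

print(sparse_ce(probs, label), categorical_ce(probs, one_hot))  # both equal -log(0.7)
```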
    

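    Running the code above took roughly 1700+ s on my machine (the eight epochs in the log below add up to about 1,820 s). PS: a GPU with more memory, such as an RTX 40-series card, may help.
    The output is as follows: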

    2022-11-29 17:43:01.082836: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX AVX2
    To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
    2022-11-29 17:43:01.449655: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5472 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3070, pci bus id: 0000:01:00.0, compute capability: 8.6
    Epoch 1/50
    2022-11-29 17:43:18.284528: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8204
    2022-11-29 17:43:21.378441: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
    625/625 [==============================] - 292s 457ms/step - loss: 0.2227 - accuracy: 0.9361 - val_loss: 0.1201 - val_accuracy: 0.9630
    Epoch 2/50
    625/625 [==============================] - 217s 348ms/step - loss: 0.1348 - accuracy: 0.9596 - val_loss: 0.1394 - val_accuracy: 0.9610
    Epoch 3/50
    625/625 [==============================] - 218s 349ms/step - loss: 0.1193 - accuracy: 0.9641 - val_loss: 0.1452 - val_accuracy: 0.9620
    Epoch 4/50
    625/625 [==============================] - 219s 350ms/step - loss: 0.1035 - accuracy: 0.9690 - val_loss: 0.1147 - val_accuracy: 0.9690
    Epoch 5/50
    625/625 [==============================] - 221s 354ms/step - loss: 0.0897 - accuracy: 0.9736 - val_loss: 0.1117 - val_accuracy: 0.9730
    Epoch 6/50
    625/625 [==============================] - 219s 351ms/step - loss: 0.0817 - accuracy: 0.9747 - val_loss: 0.1347 - val_accuracy: 0.9640
    Epoch 7/50
    625/625 [==============================] - 219s 351ms/step - loss: 0.0818 - accuracy: 0.9740 - val_loss: 0.1126 - val_accuracy: 0.9700
    Epoch 8/50
    625/625 [==============================] - 219s 350ms/step - loss: 0.0731 - accuracy: 0.9785 - val_loss: 0.1366 - val_accuracy: 0.9680
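Training stopped at epoch 8 of 50 because of EarlyStopping: with the default `monitor='val_loss'` and `patience=3`, it stops once `val_loss` has failed to improve for 3 consecutive epochs and, with `restore_best_weights=True`, restores the best epoch's weights. `ModelCheckpoint(..., save_best_only=True)` also monitors `val_loss` by default, so the saved `.h5` should hold the epoch-5 weights (val_accuracy 0.9730). The sketch below replays that rule in plain Python on the logged values; it mirrors the callback's behavior with `min_delta=0` but is not the Keras implementation itself.

```python
def early_stop_epoch(val_losses, patience=3):
    """Return (stop_epoch, best_epoch), both 1-based, for a 'min' monitor with min_delta=0."""
    best = float("inf")
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0  # an improvement resets the counter
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch  # ran out of epochs without stopping early


# val_loss and wall time (s) per epoch, copied from the training log above
val_losses = [0.1201, 0.1394, 0.1452, 0.1147, 0.1117, 0.1347, 0.1126, 0.1366]
epoch_seconds = [292, 217, 218, 219, 221, 219, 219, 219]

print(early_stop_epoch(val_losses))  # (8, 5)
print(sum(epoch_seconds))            # 1824
```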
    

    Model evaluation

    Evaluation code (animalv2_model_evaluate.py)
    from keras.models import load_model
    import tensorflow as tf
    from tensorflow.keras.utils import load_img, img_to_array
    import numpy as np
    import os
    
    import matplotlib.pyplot as plt
    
    root_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Training Data/'
    
    # Same alphanumeric order that flow_from_directory used for the label indices
    class_names = sorted(os.listdir(root_path))
    
    # Best checkpoint written by ModelCheckpoint during training
    model = load_model('./ResNet50V2_V2.h5')
    model.summary()
    
    def load_image(path):
        # Must match the training preprocessing: scale to [0, 1] and resize to 256x256
        image = tf.cast(tf.image.resize(img_to_array(load_img(path))/255., (256,256)), tf.float32)
        return image
    
    def predict_and_show(path):
        image = load_image(path)
        # Add a batch dimension: (256, 256, 3) -> (1, 256, 256, 3)
        preds = model.predict(image[np.newaxis, ...])[0]
        print(preds)
    
        idx = np.argmax(preds)
        pred_class = class_names[idx]
        confidence_score = np.round(preds[idx], 2)
    
        # Configure Title
        title = f"Pred : {pred_class}\nConfidence : {confidence_score:.2}"
        print(title)
    
        plt.figure(figsize=(25, 8))
        plt.title(title)
        plt.imshow(image)
        plt.show()
    
    i_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Validation Data/Gorilla/Gorilla (3).jpeg'
    predict_and_show(i_path)
    
    # Interactive mode: enter a local image path, or "q!" to quit
    while True:
        path = input("input:")
        if path == "q!":
            break
        predict_and_show(path)
    
    Evaluation output
    Model: "ResNet50V2"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #   
    =================================================================
     resnet50v2 (Functional)     (None, 8, 8, 2048)        23564800  
                                                                     
     global_average_pooling2d (G  (None, 2048)             0         
     lobalAveragePooling2D)                                          
                                                                     
     dense (Dense)               (None, 256)               524544    
                                                                     
     dropout (Dropout)           (None, 256)               0         
                                                                     
     dense_1 (Dense)             (None, 10)                2570      
                                                                     
    =================================================================
    Total params: 24,091,914
    Trainable params: 527,114
    Non-trainable params: 23,564,800
    _________________________________________________________________
    2022-11-29 20:33:15.981925: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8204
    2022-11-29 20:33:18.070138: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
    1/1 [==============================] - 3s 3s/step
    [1.2199847e-09 1.0668253e-12 6.8980124e-13 1.0352933e-08 9.9999988e-01
     4.1255888e-09 7.1100374e-08 3.0439090e-10 3.1216061e-11 2.8051938e-12]
    Pred : Gorilla
    Confidence : 1.0
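The summary shows that only the new head is trained, because `base_model.trainable = False`. The backbone outputs (8, 8, 2048), global average pooling reduces that to 2048 features, and a `Dense` layer has inputs × units weights plus units biases. The arithmetic below reproduces the parameter counts reported in the summary; it is a plain-Python check, not a Keras call.

```python
def dense_params(inputs, units):
    """Weights plus biases of a fully connected layer."""
    return inputs * units + units


backbone = 23_564_800  # frozen ResNet50V2, reported as Non-trainable params
# GlobalAveragePooling2D and Dropout have no parameters
head = dense_params(2048, 256) + dense_params(256, 10)

print(dense_params(2048, 256))  # 524544
print(dense_params(256, 10))    # 2570
print(head)                     # 527114
print(backbone + head)          # 24091914
```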
    

    I also added an input prompt, so any local image can be checked by entering its path (enter q! to quit):


    input:./animal/Animals_Classification/Animal-Data-V2/Data-V2/Validation Data/Zebra/Zebra-Valid (276).jpeg
    1/1 [==============================] - 0s 21ms/step
    [1.5658158e-12 1.6018555e-10 9.6812911e-13 6.2212702e-10 5.4042397e-09
     5.8055113e-05 4.7865592e-12 3.4024495e-12 3.0037000e-08 9.9994195e-01]
    Pred : Zebra
    Confidence : 1.0
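"Confidence : 1.0" is the top probability rounded to two decimals (0.99994195 for this zebra), not literal certainty. Listing the few most probable classes can be more informative. A plain-Python sketch using the probability vector printed above; the helper `top_k` is hypothetical, and because it only uses `len` and indexing, it should also work on the NumPy row returned by `model.predict(...)[0]`.

```python
def top_k(preds, class_names, k=3):
    """Return the k most probable (class, probability) pairs, highest first."""
    order = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
    return [(class_names[i], float(preds[i])) for i in order[:k]]


class_names = ['Cat', 'Cow', 'Dog', 'Elephant', 'Gorilla',
               'Hippo', 'Monkey', 'Panda', 'Tiger', 'Zebra']
# Softmax output for the zebra image, copied from the output above
preds = [1.5658158e-12, 1.6018555e-10, 9.6812911e-13, 6.2212702e-10, 5.4042397e-09,
         5.8055113e-05, 4.7865592e-12, 3.4024495e-12, 3.0037000e-08, 9.9994195e-01]

for name, p in top_k(preds, class_names):
    print(f"{name}: {p:.6f}")
```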
    
  • Original article: https://blog.csdn.net/sinom21/article/details/128094714