• 「MobileNet V3」Image Classification of 70 Dog Breeds


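    ✅About the author: an undergraduate majoring in artificial intelligence who enjoys computers and programming and keeps this blog as a record of the learning journey.
    🍎Homepage: 小嗷犬's homepage
    🍊Website: 小嗷犬's tech site
    🥭Personal motto: To ordain conscience for Heaven and Earth, to secure life and fortune for the people, to continue the lost teachings of past sages, and to establish peace for all future generations.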



    Dataset and Notebook

    Dataset: 70 Dog Breeds-Image Data Set
    Notebook: 「MobileNet V3」70 Dog Breeds-Image Classification


    Environment Setup

    import warnings
    warnings.filterwarnings('ignore')
    

    Suppress warnings so that they do not clutter the output.

    !pip install lightning --quiet
    

    Install PyTorch Lightning.

    import random
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    sns.set_theme(style="darkgrid", font_scale=1.5, font="SimHei", rc={"axes.unicode_minus":False})
    

    Import the commonly used libraries and set the plotting style.

    import torch
    import torchmetrics
    from torch import nn, optim
    from torch.nn import functional as F
    from torch.utils.data import DataLoader
    from torchvision import transforms, datasets, models
    

    Import the PyTorch-related libraries.

    import lightning.pytorch as pl
    from lightning.pytorch.loggers import CSVLogger
    from lightning.pytorch.callbacks.early_stopping import EarlyStopping
    

    Import the PyTorch Lightning libraries.

    seed = 1
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    pl.seed_everything(seed, workers=True)
    

    Set the random seed for Python, NumPy, PyTorch, and Lightning so that runs are reproducible.
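
    As a small illustration of why the seed is fixed, the sketch below uses only Python's built-in `random` module; the same idea applies to the NumPy and PyTorch generators seeded above. The helper name `draw` is hypothetical and not part of the notebook.

```python
import random


def draw(seed, n=5):
    """Reseed Python's RNG and return n pseudo-random floats."""
    random.seed(seed)
    return [random.random() for _ in range(n)]


first = draw(1)
second = draw(1)      # same seed -> identical sequence
different = draw(2)   # different seed -> a different sequence

print(first == second)
print(first == different)
```

    Reseeding with the same value replays the same sequence, which is what makes shuffling and augmentation repeatable between runs.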


    Dataset

    batch_size = 64
    

    Set the batch size.

    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    

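    Define the preprocessing: every image is resized to 224×224 and converted to a tensor, and training images are additionally flipped horizontally at random.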
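
    `ToTensor` rescales 8-bit pixel values from [0, 255] to [0.0, 1.0], and the pipeline above stops there. Models pretrained on ImageNet are often paired with an additional per-channel normalization using the ImageNet mean and standard deviation (torchvision provides `transforms.Normalize` for this). Whether adding it would help in this notebook is untested here, so the following is only a plain-Python sketch of the arithmetic; `normalize_pixel` is a hypothetical helper.

```python
# Widely used ImageNet channel statistics (RGB order).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    scaled = [value / 255.0 for value in rgb]       # the rescaling ToTensor applies
    return [
        (s - mean) / std                            # the step Normalize would add
        for s, mean, std in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)
    ]


near_mean = normalize_pixel((124, 116, 104))  # close to the ImageNet mean color
white = normalize_pixel((255, 255, 255))
print(near_mean)
print(white)
```

    A pixel close to the ImageNet mean color maps to values near zero, while pure white maps to roughly two standard deviations above the mean in each channel.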

    train_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/train", transform=train_transform)
    val_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/valid", transform=test_transform)
    test_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/test", transform=test_transform)
    

    Read the training, validation, and test sets from their folders.

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    

    Create the data loaders.


    Visualization

    class_names = train_dataset.classes
    class_count = [train_dataset.targets.count(i) for i in range(len(class_names))]
    df = pd.DataFrame({"Class": class_names, "Count": class_count})
    
    plt.figure(figsize=(12, 20), dpi=100)
    sns.barplot(x="Count", y="Class", data=df)
    plt.tight_layout()
    plt.show()
    

    Plot the class distribution of the training set.

    Class distribution of the training set
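
    The per-class counts above call `list.count` once per class, which rescans the label list each time. A single-pass alternative is `collections.Counter`. The sketch below uses a small made-up label list and made-up class names in place of `train_dataset.targets` and `train_dataset.classes`.

```python
from collections import Counter

# Toy stand-ins for train_dataset.classes and train_dataset.targets.
class_names = ["Afghan", "Beagle", "Collie"]
targets = [0, 2, 1, 0, 0, 2]

counts = Counter(targets)   # one pass over the labels
class_count = [counts.get(i, 0) for i in range(len(class_names))]

# Same result as the list.count approach used in the notebook.
assert class_count == [targets.count(i) for i in range(len(class_names))]
print(dict(zip(class_names, class_count)))
```

    For a dataset of this size either approach is fast enough; the difference only matters when the label list is very large.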

    plt.figure(figsize=(12, 20), dpi=100)
    images, labels = next(iter(train_loader))
    for i in range(8):
        ax = plt.subplot(8, 4, i + 1)
        plt.imshow(images[i].permute(1, 2, 0).numpy())
        plt.title(class_names[labels[i]])
        plt.axis("off")
    plt.tight_layout()
    plt.show()
    

    Plot sample images from the training set.

    Samples from the training set


    Model

    class LitModel(pl.LightningModule):
        def __init__(self, num_classes=1000):
            super().__init__()
            self.model = models.mobilenet_v3_large(weights="IMAGENET1K_V2")
            # for param in self.model.parameters():
            #     param.requires_grad = False
            self.model.classifier[3] = nn.Linear(self.model.classifier[3].in_features, num_classes)
            self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
            self.precision = torchmetrics.Precision(task="multiclass", average="macro", num_classes=num_classes)
            self.recall = torchmetrics.Recall(task="multiclass", average="macro", num_classes=num_classes)
            self.f1score = torchmetrics.F1Score(task="multiclass", average="macro", num_classes=num_classes)
    
        def forward(self, x):
            x = self.model(x)
            return x
    
        def configure_optimizers(self):
            optimizer = optim.Adam(
                self.parameters(), lr=0.001, betas=(0.9, 0.99), eps=1e-08, weight_decay=1e-5
            )
            return optimizer
    
        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=True)
            self.log_dict(
                {
                    "train_acc": self.accuracy(y_hat, y),
                    "train_prec": self.precision(y_hat, y),
                    "train_recall": self.recall(y_hat, y),
                    "train_f1score": self.f1score(y_hat, y),
                },
                on_step=True,
                on_epoch=False,
                logger=True,
            )
            return loss
    
        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            self.log("val_loss", loss, on_step=False, on_epoch=True, logger=True)
            self.log_dict(
                {
                    "val_acc": self.accuracy(y_hat, y),
                    "val_prec": self.precision(y_hat, y),
                    "val_recall": self.recall(y_hat, y),
                    "val_f1score": self.f1score(y_hat, y),
                },
                on_step=False,
                on_epoch=True,
                logger=True,
            )
    
        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            self.log_dict(
                {
                    "test_acc": self.accuracy(y_hat, y),
                    "test_prec": self.precision(y_hat, y),
                    "test_recall": self.recall(y_hat, y),
                    "test_f1score": self.f1score(y_hat, y),
                }
            )
    
        def predict_step(self, batch, batch_idx, dataloader_idx=None):
            x, y = batch
            y_hat = self(x)
            preds = torch.argmax(y_hat, dim=1)
            return preds
    

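    Define the model: an ImageNet-pretrained MobileNet V3 Large whose final classifier layer is replaced to output `num_classes` logits, wrapped in a LightningModule that logs loss, accuracy, precision, recall, and F1 score.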
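
    The `average` argument of the metrics controls how per-class scores are combined. With macro averaging every class contributes equally; with micro averaging all samples are pooled, which for single-label multiclass problems gives the same number as accuracy. The plain-Python sketch below illustrates the difference for precision on made-up labels. It is not how torchmetrics is implemented, and `precision_scores` is a hypothetical helper.

```python
def precision_scores(y_true, y_pred, num_classes):
    """Return (macro_precision, micro_precision) for single-label predictions."""
    per_class = []
    total_tp = 0
    for c in range(num_classes):
        # True labels of every sample that was predicted as class c.
        predicted_c = [t for t, p in zip(y_true, y_pred) if p == c]
        tp = sum(1 for t in predicted_c if t == c)
        total_tp += tp
        # Convention used here: a class that is never predicted scores 0.
        per_class.append(tp / len(predicted_c) if predicted_c else 0.0)
    macro = sum(per_class) / num_classes   # every class weighs the same
    micro = total_tp / len(y_pred)         # pooled over all samples
    return macro, micro


# Made-up, imbalanced labels: class 0 dominates.
y_true = [0, 0, 0, 0, 0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 0, 0, 0, 0, 0, 1, 0, 2]

macro, micro = precision_scores(y_true, y_pred, num_classes=3)
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(macro, micro, accuracy)
```

    In this toy case the micro value coincides with accuracy, while the macro value differs because the two small classes count as much as the large one.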

    num_classes = len(class_names)
    model = LitModel(num_classes=num_classes)
    logger = CSVLogger("./")
    early_stop_callback = EarlyStopping(
        monitor="val_loss", min_delta=0.00, patience=5, verbose=False, mode="min"
    )
    trainer = pl.Trainer(
        max_epochs=20,
        enable_progress_bar=True,
        logger=logger,
        callbacks=[early_stop_callback],
        deterministic=True,
    )
    trainer.fit(model, train_loader, val_loader)
    

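    Train the model for up to 20 epochs, logging to CSV and stopping early if the validation loss has not improved for 5 epochs.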
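
    The `EarlyStopping` callback configured above watches `val_loss` and stops training once it has failed to improve for `patience` consecutive validation checks. The function below is a simplified plain-Python sketch of that rule, not Lightning's actual implementation; `stopping_epoch` and the loss values are hypothetical.

```python
def stopping_epoch(val_losses, patience=5, min_delta=0.0):
    """Return the 0-based epoch at which training would stop, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:   # mode="min": lower is better
            best = loss
            wait = 0                  # an improvement resets the counter
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Made-up validation losses: the best value so far occurs at epoch 1.
losses = [1.0, 0.8, 0.9, 0.85, 0.81, 0.82, 0.83, 0.7]
stop = stopping_epoch(losses, patience=5)
print(stop)
```

    In this made-up sequence the later, lower loss at epoch 7 is never reached, which is the trade-off a small `patience` makes.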

    trainer.test(model, test_loader)
    

    Test the model.


    Prediction

    pred = trainer.predict(model, test_loader)
    pred = torch.cat(pred, dim=0)
    pred = pd.DataFrame(pred.numpy(), columns=["Class"])
    pred["Class"] = pred["Class"].apply(lambda x: class_names[x])
    
    plt.figure(figsize=(12, 20), dpi=100)
    sns.countplot(y="Class", data=pred)
    plt.tight_layout()
    plt.show()
    

    Plot the class distribution of the predictions.

    Class distribution of the predictions


    Loss and Evaluation Metrics

    log_path = logger.log_dir + "/metrics.csv"
    metrics = pd.read_csv(log_path)
    x_name = "epoch"
    
    plt.figure(figsize=(8, 6), dpi=100)
    sns.lineplot(x=x_name, y="train_loss", data=metrics, label="Train Loss", linewidth=2, marker="o", markersize=10)
    sns.lineplot(x=x_name, y="val_loss", data=metrics, label="Valid Loss", linewidth=2, marker="X", markersize=12)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.tight_layout()
    plt.show()
    
    
    plt.figure(figsize=(14, 12), dpi=100)
    
    plt.subplot(2,2,1)
    sns.lineplot(x=x_name, y="train_acc", data=metrics, label="Train Accuracy", linewidth=2, marker="o", markersize=10)
    sns.lineplot(x=x_name, y="val_acc", data=metrics, label="Valid Accuracy", linewidth=2, marker="X", markersize=12)
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    
    plt.subplot(2,2,2)
    sns.lineplot(x=x_name, y="train_prec", data=metrics, label="Train Precision", linewidth=2, marker="o", markersize=10)
    sns.lineplot(x=x_name, y="val_prec", data=metrics, label="Valid Precision", linewidth=2, marker="X", markersize=12)
    plt.xlabel("Epoch")
    plt.ylabel("Precision")
    
    plt.subplot(2,2,3)
    sns.lineplot(x=x_name, y="train_recall", data=metrics, label="Train Recall", linewidth=2, marker="o", markersize=10)
    sns.lineplot(x=x_name, y="val_recall", data=metrics, label="Valid Recall", linewidth=2, marker="X", markersize=12)
    plt.xlabel("Epoch")
    plt.ylabel("Recall")
    
    plt.subplot(2,2,4)
    sns.lineplot(x=x_name, y="train_f1score", data=metrics, label="Train F1-Score", linewidth=2, marker="o", markersize=10)
    sns.lineplot(x=x_name, y="val_f1score", data=metrics, label="Valid F1-Score", linewidth=2, marker="X", markersize=12)
    plt.xlabel("Epoch")
    plt.ylabel("F1-Score")
    
    plt.tight_layout()
    plt.show()
    

    Plot how the loss and the evaluation metrics change during training.

    Loss

    Evaluation metrics
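
    Because training metrics are logged per step and validation metrics per epoch, they land in different rows of `metrics.csv`, and each row leaves the other columns empty. The sketch below shows one way to reduce such a file to one value per epoch using only the standard library. The CSV content is made-up sample data, and the column layout is an assumption modeled on the metric names logged above.

```python
import csv
import io
from collections import defaultdict

# Made-up sample in the assumed layout of metrics.csv.
sample = """epoch,step,train_loss,val_loss
0,49,2.0,
0,99,1.0,
0,99,,1.25
1,149,1.0,
1,199,0.5,
1,199,,0.75
"""


def epoch_means(csv_text, column):
    """Average the non-empty values of `column` within each epoch."""
    values = defaultdict(list)
    for row in csv.DictReader(io.StringIO(csv_text)):
        if row[column]:   # skip cells left empty by the other logging stage
            values[int(row["epoch"])].append(float(row[column]))
    return {epoch: sum(v) / len(v) for epoch, v in sorted(values.items())}


train_curve = epoch_means(sample, "train_loss")
val_curve = epoch_means(sample, "val_loss")
print(train_curve)
print(val_curve)
```

    Passing `x="epoch"` to `sns.lineplot`, as done above, performs a comparable aggregation when it draws one point per epoch.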

  • Original article: https://blog.csdn.net/qq_63585949/article/details/134537365