• 计算机视觉之Vision Transformer图像分类


    Vision Transformer(ViT)简介

    自注意结构模型的发展,特别是Transformer模型的出现,极大推动了自然语言处理模型的发展。Transformers的计算效率和可扩展性使其能够训练具有超过100B参数的规模空前的模型。ViT是自然语言处理和计算机视觉的结合,能够在图像分类任务上取得良好效果,而不依赖卷积操作。

    Vision Transformer(ViT)简介

    近些年,随着基于自注意(Self-Attention)结构的模型的发展,特别是Transformer模型的提出,极大地促进了自然语言处理模型的发展。由于Transformers的计算效率和可扩展性,它已经能够训练具有超过100B参数的空前规模的模型。

    ViT则是自然语言处理和计算机视觉两个领域的融合结晶。在不依赖卷积操作的情况下,依然可以在图像分类任务上达到很好的效果。

    模型结构

    ViT模型的主体结构是基于Transformer模型的Encoder部分(部分结构顺序有调整,如:Normalization的位置与标准Transformer不同),其结构图[1]如下:

    vit-architecture

    模型特点

    ViT模型是一种用于图像分类的模型,将原图像划分为多个图像块,然后将这些图像块转换为一维向量,加上类别向量和位置向量作为模型输入。模型主体采用基于Transformer的Encoder结构,但调整了Normalization的位置,其中最主要的结构是Multi-head Attention。模型在Blocks堆叠后接全连接层,使用类别向量的输出进行分类,通常将全连接层称为Head,Transformer Encoder部分称为backbone。

    Transformer基本原理

    Transformer模型源于2017年的一篇文章[2]。在这篇文章中提出的基于Attention机制的编码器-解码器型结构在自然语言处理领域获得了巨大的成功。模型结构如下图所示:

    transformer-architecture

    模型训练

    模型训练前需要设定损失函数、优化器、回调函数等,以及建议根据项目需要调整epoch_size。训练ViT模型需要很长时间,可以通过输出的信息查看训练的进度和指标。

    1. from mindspore.nn import LossBase
    2. from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
    3. from mindspore import train
    4. # define super parameter
    5. epoch_size = 10
    6. momentum = 0.9
    7. num_classes = 1000
    8. resize = 224
    9. step_size = dataset_train.get_dataset_size()
    10. # construct model
    11. network = ViT()
    12. # load ckpt
    13. vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
    14. path = "./ckpt/vit_b_16_224.ckpt"
    15. vit_path = download(vit_url, path, replace=True)
    16. param_dict = ms.load_checkpoint(vit_path)
    17. ms.load_param_into_net(network, param_dict)
    18. # define learning rate
    19. lr = nn.cosine_decay_lr(min_lr=float(0),
    20. max_lr=0.00005,
    21. total_step=epoch_size * step_size,
    22. step_per_epoch=step_size,
    23. decay_epoch=10)
    24. # define optimizer
    25. network_opt = nn.Adam(network.trainable_params(), lr, momentum)
    26. # define loss function
    27. class CrossEntropySmooth(LossBase):
    28. """CrossEntropy."""
    29. def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
    30. super(CrossEntropySmooth, self).__init__()
    31. self.onehot = ops.OneHot()
    32. self.sparse = sparse
    33. self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
    34. self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
    35. self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
    36. def construct(self, logit, label):
    37. if self.sparse:
    38. label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
    39. loss = self.ce(logit, label)
    40. return loss
    41. network_loss = CrossEntropySmooth(sparse=True,
    42. reduction="mean",
    43. smooth_factor=0.1,
    44. num_classes=num_classes)
    45. # set checkpoint
    46. ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
    47. ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)
    48. # initialize model
    49. # "Ascend + mixed precision" can improve performance
    50. ascend_target = (ms.get_context("device_target") == "Ascend")
    51. if ascend_target:
    52. model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
    53. else:
    54. model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")
    55. # train model
    56. model.train(epoch_size,
    57. dataset_train,
    58. callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
    59. dataset_sink_mode=False,)

    总结

    本案例演示了如何在ImageNet数据集上训练、验证和推断ViT模型。通过讲解ViT模型的关键结构和原理,帮助用户理解Multi-Head Attention、TransformerEncoder和pos_embedding等关键概念。建议用户基于源码深入学习,以更详细地理解ViT模型的原理。

  • 相关阅读:
    Linux 扩展篇 YUM+Shell编程
    LeetCode 每日一题 2023/11/13-2023/11/19
    2020 款丰田雷凌车组合仪表上多个故障灯偶发点亮
    丝绸之路网络安全论坛成功举办,开源网安受邀分享软件供应链安全落地经验
    MySQL 单表查询 多表设计
    大数据Doris(二十五):数据导入演示和其他导入案例
    计算机毕业设计(附源码)python自习室管理系统
    初探C++ CRTP(奇异的递归模板模式)
    UE4光照基础
    Vue3+node.js网易云音乐实战项目(五)
  • 原文地址:https://blog.csdn.net/qq_33816117/article/details/140408114