• ViT Source Code Walkthrough


    1. Project Configuration

            Argument description:

    Dataset:

           --name cifar10-100_500

           --dataset cifar10

    Model variant:

           --model_type ViT-B_16

    Pre-trained weights:

           --pretrained_dir checkpoint/ViT-B_16.npz
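
            Putting these arguments together, a typical training launch looks roughly like the following (a sketch assuming the repository's train.py entry point, as in the upstream ViT-pytorch project; adjust the script name and paths to your checkout):

    python3 train.py --name cifar10-100_500 \
                     --dataset cifar10 \
                     --model_type ViT-B_16 \
                     --pretrained_dir checkpoint/ViT-B_16.npz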

    2. Patch Embedding and Position Embedding

            For image encoding, take ViT-B/16 as an example. A convolution with a 16x16 kernel and stride 16 is first applied to the image; with a batch of 16 this produces a tensor of shape 16 x 768 x 14 x 14, which is then reshaped to [16, 196, 768]. A class token of shape 16 x 1 x 768 is concatenated in front as patch 0, giving [16, 197, 768].

            For position encoding, a learnable tensor of shape 1 x 197 x 768 is created.

            Finally, the patch embeddings and the position embeddings are added element-wise, which completes the embedding step.
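
            The shape arithmetic above can be checked with a few plain PyTorch operations (a minimal sketch, independent of the project's config object; the sizes match ViT-B/16 with a batch of 16 and 224x224 inputs):

    import torch
    import torch.nn as nn

    x = torch.randn(16, 3, 224, 224)                     # batch of 16 RGB images
    proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)  # 16x16 patches -> 768-dim tokens
    x = proj(x)                                          # [16, 768, 14, 14]   (224 / 16 = 14)
    x = x.flatten(2).transpose(1, 2)                     # [16, 196, 768]      (14 * 14 = 196 patches)
    cls = torch.zeros(16, 1, 768)                        # class token, patch 0
    x = torch.cat([cls, x], dim=1)                       # [16, 197, 768]
    pos = torch.zeros(1, 197, 768)                       # position embedding (a learnable parameter in the model)
    x = x + pos                                          # broadcast add over the batch
    print(x.shape)                                       # torch.Size([16, 197, 768])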

    The code is as follows:

    import math
    import torch
    import torch.nn as nn
    from torch.nn import Conv2d, Dropout, Linear, Softmax, LayerNorm, CrossEntropyLoss
    from torch.nn.modules.utils import _pair
    # ResNetV2 (hybrid backbone), Mlp and Transformer are defined elsewhere in the repository's model code

    class Embeddings(nn.Module):
        """Construct the embeddings from patch, position embeddings."""
        def __init__(self, config, img_size, in_channels=3):
            super(Embeddings, self).__init__()
            self.hybrid = None
            img_size = _pair(img_size)
            # patch size and number of patches n_patches
            if config.patches.get("grid") is not None:
                grid_size = config.patches["grid"]
                patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
                n_patches = (img_size[0] // 16) * (img_size[1] // 16)
                self.hybrid = True
            else:
                patch_size = _pair(config.patches["size"])
                n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
                self.hybrid = False
            # hybrid model: a ResNet backbone feeds the patch projection
            if self.hybrid:
                self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
                                             width_factor=config.resnet.width_factor)
                in_channels = self.hybrid_model.width * 16
            # patch embedding, output: 16 x 768 x 14 x 14
            self.patch_embeddings = Conv2d(in_channels=in_channels,
                                           out_channels=config.hidden_size,
                                           kernel_size=patch_size,
                                           stride=patch_size)
            # position embeddings: 1 x 197 x 768
            self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
            # class token (patch 0), representing the classification feature: 1 x 1 x 768
            self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
            # dropout layer
            self.dropout = Dropout(config.transformer["dropout_rate"])

        def forward(self, x):
            print(x.shape)
            B = x.shape[0]
            # expand cls_token over the batch: 16 x 1 x 768
            cls_tokens = self.cls_token.expand(B, -1, -1)
            print(cls_tokens.shape)
            # hybrid model
            if self.hybrid:
                x = self.hybrid_model(x)
            # patch embedding: 16 x 768 x 14 x 14
            x = self.patch_embeddings(x)
            print(x.shape)
            # flatten spatial dims: 16 x 768 x 14 x 14 --> [16, 768, 196]
            x = x.flatten(2)
            print(x.shape)
            # [16, 768, 196] --> [16, 196, 768]
            x = x.transpose(-1, -2)
            print(x.shape)
            # prepend the classification token
            x = torch.cat((cls_tokens, x), dim=1)
            print(x.shape)
            # add the position embeddings
            embeddings = x + self.position_embeddings
            print(embeddings.shape)
            # dropout layer
            embeddings = self.dropout(embeddings)
            print(embeddings.shape)
            return embeddings
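
            To exercise this class outside the project, a throwaway config object exposing only the fields read by the code above and below is enough (a minimal sketch; the real project builds its config objects elsewhere, and the values here simply mirror ViT-B/16):

    from types import SimpleNamespace

    # hypothetical stand-in for the ViT-B/16 config, covering only the fields these snippets use
    config = SimpleNamespace(
        hidden_size=768,
        patches={"size": (16, 16)},
        transformer={"dropout_rate": 0.1,
                     "attention_dropout_rate": 0.0,
                     "num_heads": 12},
    )

    emb = Embeddings(config, img_size=224)
    out = emb(torch.randn(16, 3, 224, 224))   # prints the intermediate shapes; out is [16, 197, 768]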

    3. Encoder

    Multi-head attention module:

            First, the q, k, v vectors are built from the input. Because multi-head attention with 12 heads is used, q, k and v are reshaped from [16, 197, 768] to [16, 12, 197, 64]. The similarity between q and k is then computed; since it captures the pairwise relations among all 197 tokens, its shape is [16, 12, 197, 197]. These scores are divided by the square root of the head dimension to remove the scale effect and passed through softmax, and the resulting attention weights are multiplied by v to give the per-head feature tensor of shape [16, 12, 197, 64], which is finally reshaped back to [16, 197, 768].
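
            The same shape flow can be reproduced with raw tensor operations (a minimal sketch of scaled dot-product attention, not the project's code; dropout and the learned projections are omitted):

    import math
    import torch

    B, N, H, D = 16, 197, 12, 64                        # batch, tokens, heads, head dim (H * D = 768)
    q = torch.randn(B, H, N, D)
    k = torch.randn(B, H, N, D)
    v = torch.randn(B, H, N, D)

    scores = q @ k.transpose(-1, -2) / math.sqrt(D)     # [16, 12, 197, 197], scaled similarities
    probs = scores.softmax(dim=-1)                      # attention weights over the 197 tokens
    ctx = probs @ v                                     # [16, 12, 197, 64]
    ctx = ctx.permute(0, 2, 1, 3).reshape(B, N, H * D)  # back to [16, 197, 768]
    print(ctx.shape)

            The project's Attention module implements the same computation with learned query, key, value and output projections: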

    class Attention(nn.Module):
        def __init__(self, config, vis):
            super(Attention, self).__init__()
            self.vis = vis
            # number of heads
            self.num_attention_heads = config.transformer["num_heads"]
            # per-head vector dimension
            self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
            # total head size
            self.all_head_size = self.num_attention_heads * self.attention_head_size
            # query projection
            self.query = Linear(config.hidden_size, self.all_head_size)
            # key projection
            self.key = Linear(config.hidden_size, self.all_head_size)
            # value projection
            self.value = Linear(config.hidden_size, self.all_head_size)
            # output projection
            self.out = Linear(config.hidden_size, config.hidden_size)
            # dropout layers
            self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
            self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
            self.softmax = Softmax(dim=-1)

        def transpose_for_scores(self, x):
            # shapes: 16, 197, 768 --> 16, 197, 12, 64
            new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
            # print(new_x_shape)
            x = x.view(*new_x_shape)
            # print(x.shape)
            # print(x.permute(0, 2, 1, 3).shape)
            # 16, 197, 12, 64 --> 16, 12, 197, 64
            return x.permute(0, 2, 1, 3)

        def forward(self, hidden_states):
            # print(hidden_states.shape)
            # q, k, v: 16, 197, 768
            mixed_query_layer = self.query(hidden_states)
            # print(mixed_query_layer.shape)
            mixed_key_layer = self.key(hidden_states)
            # print(mixed_key_layer.shape)
            mixed_value_layer = self.value(hidden_states)
            # print(mixed_value_layer.shape)
            # q, k, v: 16, 197, 768 --> 16, 12, 197, 64
            query_layer = self.transpose_for_scores(mixed_query_layer)
            # print(query_layer.shape)
            key_layer = self.transpose_for_scores(mixed_key_layer)
            # print(key_layer.shape)
            value_layer = self.transpose_for_scores(mixed_value_layer)
            # print(value_layer.shape)
            # similarity between q and k: 16, 12, 197, 197
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            # print(attention_scores.shape)
            # scale to remove the effect of the head dimension
            attention_scores = attention_scores / math.sqrt(self.attention_head_size)
            # print(attention_scores.shape)
            attention_probs = self.softmax(attention_scores)
            # print(attention_probs.shape)
            weights = attention_probs if self.vis else None
            attention_probs = self.attn_dropout(attention_probs)
            # print(attention_probs.shape)
            # print(value_layer.shape)
            # weighted values: 16, 12, 197, 64
            context_layer = torch.matmul(attention_probs, value_layer)
            # print(context_layer.shape)
            # 16, 12, 197, 64 --> 16, 197, 12, 64
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            # print(context_layer.shape)
            # 16, 197, 12, 64 --> 16, 197, 768
            new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
            context_layer = context_layer.view(*new_context_layer_shape)
            # print(context_layer.shape)
            # output projection: 16, 197, 768
            attention_output = self.out(context_layer)
            # print(attention_output.shape)
            # dropout layer
            attention_output = self.proj_dropout(attention_output)
            # print(attention_output.shape)
            return attention_output, weights
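
            With the throwaway config from the embedding section, the attention module can be checked in isolation (a usage sketch, not part of the original post's code):

    attn = Attention(config, vis=True)
    out, weights = attn(torch.randn(16, 197, 768))
    print(out.shape)        # torch.Size([16, 197, 768])
    print(weights.shape)    # torch.Size([16, 12, 197, 197])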

          Transformer encoder

            The input x is first layer-normalized and fed into multi-head attention, and the result is added back to the input through a residual connection. It is then layer-normalized again and passed through a two-layer MLP, followed by another residual connection, which yields the output of one block. Stacking L such blocks produces the final encoder output (see the stacking sketch after the Block code below).

     

    class Block(nn.Module):
        def __init__(self, config, vis):
            super(Block, self).__init__()
            # hidden size of each token: 768
            self.hidden_size = config.hidden_size
            # layer normalization
            self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
            self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
            # MLP layer
            self.ffn = Mlp(config)
            # multi-head attention
            self.attn = Attention(config, vis)

        def forward(self, x):
            # print(x.shape)
            # 16, 197, 768
            h = x
            # layer normalization
            x = self.attention_norm(x)
            # print(x.shape)
            # multi-head attention
            x, weights = self.attn(x)
            # residual connection
            x = x + h
            # print(x.shape)
            h = x
            # layer normalization
            x = self.ffn_norm(x)
            # print(x.shape)
            # MLP layer
            x = self.ffn(x)
            # print(x.shape)
            # residual connection
            x = x + h
            # print(x.shape)
            return x, weights
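
            The repository's Encoder class (not reproduced in this post) does little more than stack L of these Blocks and apply a final LayerNorm. A minimal sketch of that stacking loop under that assumption (MiniEncoder is a hypothetical name; it relies on the Block above and hence on the repo's Mlp):

    class MiniEncoder(nn.Module):
        # hypothetical minimal stand-in for the project's Encoder: L stacked Blocks plus a final LayerNorm
        def __init__(self, config, vis, num_layers=12):
            super(MiniEncoder, self).__init__()
            self.layers = nn.ModuleList([Block(config, vis) for _ in range(num_layers)])
            self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)

        def forward(self, x):
            attn_weights = []
            for layer in self.layers:
                x, w = layer(x)
                if w is not None:
                    attn_weights.append(w)   # one [16, 12, 197, 197] tensor per layer when vis=True
            return self.encoder_norm(x), attn_weights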

    Overall architecture

            The input x first goes through patch embedding and position embedding, giving a tensor of shape 16 x 197 x 768. This is fed into the encoder and passed through L encoder blocks; the encoding of patch 0 (the class token, representing the classification feature) is then taken out and fed into the classification head to produce the prediction.

    class VisionTransformer(nn.Module):
        def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
            super(VisionTransformer, self).__init__()
            self.num_classes = num_classes
            self.zero_head = zero_head
            self.classifier = config.classifier
            self.transformer = Transformer(config, img_size, vis)
            self.head = Linear(config.hidden_size, num_classes)

        def forward(self, x, labels=None):
            x, attn_weights = self.transformer(x)
            print(x.shape)
            # x: [16, 197, 768]   logits: [16, 10]
            logits = self.head(x[:, 0])
            print(logits.shape)
            # cross-entropy loss
            if labels is not None:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
                return loss
            else:
                return logits, attn_weights
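
            End to end, the model can be driven roughly as follows (a sketch assuming the upstream ViT-pytorch repository layout, where models.modeling exposes VisionTransformer and a CONFIGS dictionary; the module path and names are assumptions, adjust them to your checkout):

    import torch
    from models.modeling import VisionTransformer, CONFIGS   # assumed module path

    config = CONFIGS["ViT-B_16"]
    model = VisionTransformer(config, img_size=224, num_classes=10, zero_head=True)

    imgs = torch.randn(16, 3, 224, 224)
    labels = torch.randint(0, 10, (16,))

    loss = model(imgs, labels)        # scalar cross-entropy loss for training
    logits, attn = model(imgs)        # logits: [16, 10]; attn: per-layer attention weights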


  • Original article: https://blog.csdn.net/qq_52053775/article/details/126261070