• YOLOv5 / YOLOv8 Improvement: Swin Transformer V2


1. Introduction

    Paper: https://arxiv.org/abs/2111.09883

    Overview

    The authors present techniques for scaling Swin Transformer up to 3 billion parameters and making it trainable with images of up to 1,536×1,536 resolution. By scaling up capacity and resolution, Swin Transformer sets new records on four representative vision benchmarks: 84.0% top-1 accuracy on ImageNet-V2 image classification, 63.1 / 54.4 box / mask mAP on COCO object detection, 59.9 mIoU on ADE20K semantic segmentation, and 86.8% top-1 accuracy on Kinetics-400 video action classification. These techniques are generally applicable to scaling up vision models, which has not been explored as widely as for NLP language models, partly because of the following difficulties in training and application: 1) vision models often face instability issues at scale, and 2) many downstream vision tasks require high-resolution images or windows, and it is not clear how to effectively transfer models pre-trained at low resolution to higher-resolution counterparts. GPU memory consumption is also a problem when the image resolution is high. To address these issues, the paper presents several techniques, illustrated with Swin Transformer as a case study: 1) a post-normalization technique and a scaled cosine attention approach to improve the stability of large vision models; 2) a log-spaced continuous position bias technique to effectively transfer models pre-trained on low-resolution images and windows to their higher-resolution counterparts. In addition, the authors share key implementation details that significantly reduce GPU memory consumption, making it feasible to train large vision models on regular GPUs. Using these techniques together with self-supervised pre-training, they successfully train a strong 3-billion-parameter Swin Transformer model and transfer it effectively to various vision tasks involving high-resolution images or windows, achieving state-of-the-art accuracy on a variety of benchmarks. The code will be released.

    Problems to solve
    Vision models often face instability issues when scaled up;

    Downstream tasks require high-resolution images, and it is not yet clear how to transfer a model pre-trained at low resolution to a high-resolution version; when the image resolution is very large, GPU memory consumption is also a problem.

    Proposed improvements
    A post-normalization technique and scaled cosine attention to improve the stability of large vision models (a minimal sketch follows this list);
    A log-spaced continuous position bias technique to transfer models pre-trained at low resolution to high-resolution models;
    Key implementation details are also shared that substantially reduce GPU memory usage, making it feasible to train large vision models.
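
    As a quick intuition for the first two changes, here is a minimal, self-contained sketch (not the paper's official code; tensor shapes and names are illustrative only) of scaled cosine attention and of the log-spaced mapping of relative coordinates that feeds the continuous position bias MLP. Both ideas reappear inside WindowAttention_v2 in section 2.2.

    import torch
    import torch.nn.functional as F

    def scaled_cosine_attention(q, k, v, logit_scale):
        # q, k, v: (num_heads, N, head_dim); logit_scale: one learnable scalar per head
        attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)  # cosine similarity
        tau = torch.clamp(logit_scale, max=torch.log(torch.tensor(100.0))).exp()  # learned temperature, capped at 100
        return (attn * tau).softmax(dim=-1) @ v

    def log_spaced_coords(delta, base=8.0):
        # map relative coordinate offsets to log space: sign(x) * log2(1 + |x|) / log2(base)
        return torch.sign(delta) * torch.log2(1.0 + delta.abs()) / torch.log2(torch.tensor(base))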

    2. YOLOv5 Modifications

    2.1 YOLOv5 yaml configuration file

    First, add the following yolov5_swin_transfomrer.yaml file:

    # YOLOv5 🚀 by Ultralytics, GPL-3.0 license

    # Parameters
    nc: 80  # number of classes
    depth_multiple: 0.33  # model depth multiple
    width_multiple: 0.50  # layer channel multiple
    anchors:
      - [10,13, 16,30, 33,23]  # P3/8
      - [30,61, 62,45, 59,119]  # P4/16
      - [116,90, 156,198, 373,326]  # P5/32

    # YOLOv5 v6.0 backbone by yoloair
    backbone:
      # [from, number, module, args]
      [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
       [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
       [-1, 3, C3, [128]],
       [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
       [-1, 6, C3, [256]],
       [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
       [-1, 9, SwinV2_CSPB, [256, 256]],
       [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
       [-1, 3, SwinV2_CSPB, [512, 512]],  # 8 <--- ST2CSPB() Transformer module
       [-1, 1, SPPF, [1024, 5]],  # 9
      ]

    # YOLOv5 v6.0 head
    head:
      [[-1, 1, Conv, [512, 1, 1]],
       [-1, 1, nn.Upsample, [None, 2, 'nearest']],
       [[-1, 6], 1, Concat, [1]],  # cat backbone P4
       [-1, 3, C3, [512, False]],  # 13
       [-1, 1, Conv, [256, 1, 1]],
       [-1, 1, nn.Upsample, [None, 2, 'nearest']],
       [[-1, 4], 1, Concat, [1]],  # cat backbone P3
       [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
       [-1, 1, Conv, [256, 3, 2]],
       [[-1, 14], 1, Concat, [1]],  # cat head P4
       [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
       [-1, 1, Conv, [512, 3, 2]],
       [[-1, 10], 1, Concat, [1]],  # cat head P5
       [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
       [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
      ]
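
    Compared with the standard yolov5s v6.0 layout, this config replaces the C3 blocks at layer 6 (P4/16) and layer 8 (just before SPPF) with SwinV2_CSPB blocks; the head is unchanged. The file can then be used like any other model yaml, for example `python train.py --cfg models/yolov5_swin_transfomrer.yaml --data coco128.yaml --weights '' --img 640` (paths and dataset here are only an illustration).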

    2.2 common.py configuration

    Add the following modules to the ./models/common.py file; they can be copied in directly.

    class WindowAttention_v2(nn.Module):

        def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                     pretrained_window_size=[0, 0]):
            super().__init__()
            self.dim = dim
            self.window_size = window_size  # Wh, Ww
            self.pretrained_window_size = pretrained_window_size
            self.num_heads = num_heads

            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)

            # mlp to generate continuous relative position bias
            self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                         nn.ReLU(inplace=True),
                                         nn.Linear(512, num_heads, bias=False))

            # get relative_coords_table
            relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
            relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
            relative_coords_table = torch.stack(
                torch.meshgrid([relative_coords_h,
                                relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
            if pretrained_window_size[0] > 0:
                relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
                relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
            else:
                relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
                relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
            relative_coords_table *= 8  # normalize to -8, 8
            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
                torch.abs(relative_coords_table) + 1.0) / np.log2(8)

            self.register_buffer("relative_coords_table", relative_coords_table)

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(self.window_size[0])
            coords_w = torch.arange(self.window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)

            self.qkv = nn.Linear(dim, dim * 3, bias=False)
            if qkv_bias:
                self.q_bias = nn.Parameter(torch.zeros(dim))
                self.v_bias = nn.Parameter(torch.zeros(dim))
            else:
                self.q_bias = None
                self.v_bias = None
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)
            self.softmax = nn.Softmax(dim=-1)

        def forward(self, x, mask=None):
            B_, N, C = x.shape
            qkv_bias = None
            if self.q_bias is not None:
                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
            qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

            # cosine attention
            attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
            logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
            attn = attn * logit_scale

            relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
            relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
            attn = attn + relative_position_bias.unsqueeze(0)

            if mask is not None:
                nW = mask.shape[0]
                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
                attn = attn.view(-1, self.num_heads, N, N)
                attn = self.softmax(attn)
            else:
                attn = self.softmax(attn)

            attn = self.attn_drop(attn)

            try:
                x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
            except:
                x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)

            x = self.proj(x)
            x = self.proj_drop(x)
            return x

        def extra_repr(self) -> str:
            return f'dim={self.dim}, window_size={self.window_size}, ' \
                   f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'

        def flops(self, N):
            # calculate flops for 1 window with token length of N
            flops = 0
            # qkv = self.qkv(x)
            flops += N * self.dim * 3 * self.dim
            # attn = (q @ k.transpose(-2, -1))
            flops += self.num_heads * N * (self.dim // self.num_heads) * N
            # x = (attn @ v)
            flops += self.num_heads * N * N * (self.dim // self.num_heads)
            # x = self.proj(x)
            flops += N * self.dim * self.dim
            return flops


    class Mlp_v2(nn.Module):
        def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
            super().__init__()
            out_features = out_features or in_features
            hidden_features = hidden_features or in_features
            self.fc1 = nn.Linear(in_features, hidden_features)
            self.act = act_layer()
            self.fc2 = nn.Linear(hidden_features, out_features)
            self.drop = nn.Dropout(drop)

        def forward(self, x):
            x = self.fc1(x)
            x = self.act(x)
            x = self.drop(x)
            x = self.fc2(x)
            x = self.drop(x)
            return x


    # add 2 functions
    class SwinTransformerLayer_v2(nn.Module):

        def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                     mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                     act_layer=nn.SiLU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
            super().__init__()
            self.dim = dim
            # self.input_resolution = input_resolution
            self.num_heads = num_heads
            self.window_size = window_size
            self.shift_size = shift_size
            self.mlp_ratio = mlp_ratio
            # if min(self.input_resolution) <= self.window_size:
            #     # if window size is larger than input resolution, we don't partition windows
            #     self.shift_size = 0
            #     self.window_size = min(self.input_resolution)
            assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

            self.norm1 = norm_layer(dim)
            self.attn = WindowAttention_v2(
                dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
                qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
                pretrained_window_size=(pretrained_window_size, pretrained_window_size))

            self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp_v2(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        def create_mask(self, H, W):
            # calculate attention mask for SW-MSA
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
            return attn_mask

        def forward(self, x):
            # reshape x[b c h w] to x[b l c]
            _, _, H_, W_ = x.shape
            Padding = False
            if min(H_, W_) < self.window_size or H_ % self.window_size != 0 or W_ % self.window_size != 0:
                Padding = True
                # print(f'img_size {min(H_, W_)} is less than (or not divided by) window_size {self.window_size}, Padding.')
                pad_r = (self.window_size - W_ % self.window_size) % self.window_size
                pad_b = (self.window_size - H_ % self.window_size) % self.window_size
                x = F.pad(x, (0, pad_r, 0, pad_b))

            # print('2', x.shape)
            B, C, H, W = x.shape
            L = H * W
            x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)  # b, L, c

            # create mask from init to forward
            if self.shift_size > 0:
                attn_mask = self.create_mask(H, W).to(x.device)
            else:
                attn_mask = None

            shortcut = x
            x = x.view(B, H, W, C)

            # cyclic shift
            if self.shift_size > 0:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            else:
                shifted_x = x

            # partition windows
            x_windows = window_partition_v2(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

            # W-MSA/SW-MSA
            attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

            # merge windows
            attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
            shifted_x = window_reverse_v2(attn_windows, self.window_size, H, W)  # B H' W' C

            # reverse cyclic shift
            if self.shift_size > 0:
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = shifted_x
            x = x.view(B, H * W, C)
            x = shortcut + self.drop_path(self.norm1(x))

            # FFN
            x = x + self.drop_path(self.norm2(self.mlp(x)))
            x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W)  # b c h w

            if Padding:
                x = x[:, :, :H_, :W_]  # reverse padding

            return x

        def extra_repr(self) -> str:
            return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
                   f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

        def flops(self):
            flops = 0
            H, W = self.input_resolution
            # norm1
            flops += self.dim * H * W
            # W-MSA/SW-MSA
            nW = H * W / self.window_size / self.window_size
            flops += nW * self.attn.flops(self.window_size * self.window_size)
            # mlp
            flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
            # norm2
            flops += self.dim * H * W
            return flops


    class SwinTransformer2Block(nn.Module):
        def __init__(self, c1, c2, num_heads, num_layers, window_size=7):
            super().__init__()
            self.conv = None
            if c1 != c2:
                self.conv = Conv(c1, c2)

            # remove input_resolution
            self.blocks = nn.Sequential(*[
                SwinTransformerLayer_v2(dim=c2, num_heads=num_heads, window_size=window_size,
                                        shift_size=0 if (i % 2 == 0) else window_size // 2)
                for i in range(num_layers)])

        def forward(self, x):
            if self.conv is not None:
                x = self.conv(x)
            x = self.blocks(x)
            return x


    class SwinV2_CSPB(nn.Module):
        # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
        def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
            super(SwinV2_CSPB, self).__init__()
            c_ = int(c2)  # hidden channels
            self.cv1 = Conv(c1, c_, 1, 1)
            self.cv2 = Conv(c_, c_, 1, 1)
            self.cv3 = Conv(2 * c_, c2, 1, 1)
            num_heads = c_ // 32
            self.m = SwinTransformer2Block(c_, c_, num_heads, n)
            # self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

        def forward(self, x):
            x1 = self.cv1(x)
            y1 = self.m(x1)
            y2 = self.cv2(x1)
            return self.cv3(torch.cat((y1, y2), dim=1))
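
    Note that the code above also references Conv, DropPath, window_partition, window_partition_v2 and window_reverse_v2, as well as np (numpy) and F (torch.nn.functional). Conv and numpy are already present in stock models/common.py; the torch.nn.functional import may need to be added depending on your version, and DropPath can be imported from timm (from timm.models.layers import DropPath) or copied from an earlier Swin modification. If the window helpers are missing from your common.py (the "# add 2 functions" comment above most likely refers to them), the sketch below gives the standard Swin partition/reverse utilities; this is an assumption about your fork, and the *_v2 names can simply alias the same implementations:

    def window_partition(x, window_size):
        # (B, H, W, C) -> (num_windows * B, window_size, window_size, C)
        B, H, W, C = x.shape
        x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
        return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

    def window_reverse(windows, window_size, H, W):
        # (num_windows * B, window_size, window_size, C) -> (B, H, W, C)
        B = int(windows.shape[0] / (H * W / window_size / window_size))
        x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
        return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

    # the v2 layer calls these names; reuse the same implementations
    window_partition_v2 = window_partition
    window_reverse_v2 = window_reverse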

    2.3 yolo.py configuration

    No changes are required here. In the stock repo, yolo.py does `from models.common import *`, so parse_model can already resolve the SwinV2_CSPB name once it is defined in common.py, and the yaml hard-codes the module's input/output channels to match the surrounding layers at width_multiple 0.50, so the default channel bookkeeping still works.

    The modification is complete.
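
    A quick way to confirm everything is wired correctly is to build the model from the new yaml and run a dummy forward pass. This is only a sketch; it assumes it is run from the YOLOv5 repo root with the yaml placed under models/:

    import torch
    from models.yolo import Model

    model = Model('models/yolov5_swin_transfomrer.yaml', ch=3, nc=80)
    model.info()                                # print the layer/parameter summary
    out = model(torch.zeros(1, 3, 640, 640))    # forward pass (training mode returns P3/P4/P5 outputs)
    print([o.shape for o in out])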

  • Original article: https://blog.csdn.net/weixin_45303602/article/details/133216041