• YOLOv8 + Swin Transformer v2


    Test environment: CUDA 11.3, PyTorch 1.11, RTX 3090, WSL2 Ubuntu 20.04
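
    Before starting, it can be worth confirming the toolchain; the versions above are simply what was used for this post, not hard requirements. A minimal check:

    import torch

    # CUDA builds of PyTorch report the CUDA toolkit they were built against
    print(torch.__version__, torch.version.cuda)
    print(torch.cuda.is_available(), torch.cuda.get_device_name(0) if torch.cuda.is_available() else "no GPU")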

    I hit plenty of pitfalls along the way: a lot of the code posted by other bloggers simply does not run, so I went back to the GitHub repositories and reproduced and adapted the module myself.

    Code from other blog posts also routinely ends up with tensors split between CPU and GPU; maybe they were running distributed training, haha.

    Here is the useful part, with no filler; a follow and a like are appreciated.

    YOLOv8 YAML changes (the file below is the segmentation YAML; for plain object detection, apply the same backbone edits to the detection YAML). Compared with the stock yolov8-seg.yaml, only backbone entries 6 and 8 change: the C2f blocks after the P4/16 and P5/32 downsampling convs are replaced by SwinV2_CSPB, and the head is untouched.

    # Ultralytics YOLO 🚀, AGPL-3.0 license
    # YOLOv8-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment

    # Parameters
    nc: 1 # number of classes
    scales: # model compound scaling constants, i.e. 'model=yolov8n-seg.yaml' will call yolov8-seg.yaml with scale 'n'
      # [depth, width, max_channels]
      n: [0.33, 0.25, 1024]
      s: [0.33, 0.50, 1024]
      m: [0.67, 0.75, 768]
      l: [1.00, 1.00, 512]
      x: [1.00, 1.25, 512]

    # YOLOv8.0n backbone
    backbone:
      # [from, repeats, module, args]
      - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
      - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
      - [-1, 3, C2f, [128, True]]
      - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
      - [-1, 6, C2f, [256, True]]
      - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
      - [-1, 9, SwinV2_CSPB, [512, 512]]
      - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
      - [-1, 3, SwinV2_CSPB, [1024, 1024]]
      - [-1, 1, SPPF, [1024, 5]] # 9

    # YOLOv8.0n head
    head:
      - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
      - [[-1, 6], 1, Concat, [1]] # cat backbone P4
      - [-1, 3, C2f, [512]] # 12
      - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
      - [[-1, 4], 1, Concat, [1]] # cat backbone P3
      - [-1, 3, C2f, [256]] # 15 (P3/8-small)
      - [-1, 1, Conv, [256, 3, 2]]
      - [[-1, 12], 1, Concat, [1]] # cat head P4
      - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
      - [-1, 1, Conv, [512, 3, 2]]
      - [[-1, 9], 1, Concat, [1]] # cat head P5
      - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
      - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5)
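
    For orientation, the first argument of each SwinV2_CSPB entry is the output channel count before width scaling; parse_model multiplies it by the width factor of the chosen scale. A minimal sketch of that arithmetic, assuming the helper below mirrors Ultralytics' make_divisible rounding:

    import math

    def make_divisible(x, divisor=8):
        # round up to the nearest multiple of divisor, as Ultralytics does for channel counts
        return math.ceil(x / divisor) * divisor

    width, max_channels = 0.25, 1024   # the 'n' row of the scales table above
    for c2_yaml in (512, 1024):        # first arg of the two SwinV2_CSPB rows
        c2 = make_divisible(min(c2_yaml, max_channels) * width, 8)
        print(c2_yaml, "->", c2, "channels,", c2 // 32, "heads (num_heads = c_ // 32 in SwinV2_CSPB)")
    # 512 -> 128 channels / 4 heads, 1024 -> 256 channels / 8 heads at scale 'n'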

    Append the following at the bottom of ultralytics/nn/modules/block.py:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import numpy as np
    from timm.models.layers import DropPath, to_2tuple, trunc_normal_  # requires the timm package (pip install timm)
    # Conv and TransformerBlock are normally already imported at the top of block.py;
    # re-importing them here is harmless when this snippet is pasted at the bottom of the file.
    from .conv import Conv, DWConv, GhostConv, LightConv, RepConv
    from .transformer import TransformerBlock


    class WindowAttention(nn.Module):
        # Window-based multi-head self-attention (Swin v1 style) with a learned relative position bias.
        def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
            super().__init__()
            self.dim = dim
            self.window_size = window_size  # Wh, Ww
            self.num_heads = num_heads
            head_dim = dim // num_heads
            self.scale = qk_scale or head_dim ** -0.5

            # define a parameter table of relative position bias
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(self.window_size[0])
            coords_w = torch.arange(self.window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)

            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)

            nn.init.normal_(self.relative_position_bias_table, std=.02)
            self.softmax = nn.Softmax(dim=-1)

        def forward(self, x, mask=None):
            B_, N, C = x.shape
            qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))

            relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

            if mask is not None:
                nW = mask.shape[0]
                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
                attn = attn.view(-1, self.num_heads, N, N)
                attn = self.softmax(attn)
            else:
                attn = self.softmax(attn)

            attn = self.attn_drop(attn)

            try:
                x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
            except Exception:
                # dtype mismatch under mixed precision: cast the attention map to half before the matmul
                x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)

            x = self.proj(x)
            x = self.proj_drop(x)
            return x


    class WindowAttention_v2(nn.Module):
        # Swin Transformer v2 window attention: cosine attention with a learnable logit scale
        # and a continuous, log-spaced relative position bias produced by a small MLP (cpb_mlp).
        def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                     pretrained_window_size=[0, 0]):
            super().__init__()
            self.dim = dim
            self.window_size = window_size  # Wh, Ww
            self.pretrained_window_size = pretrained_window_size
            self.num_heads = num_heads

            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)

            # mlp to generate continuous relative position bias
            self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                         nn.ReLU(inplace=True),
                                         nn.Linear(512, num_heads, bias=False))

            # get relative_coords_table
            relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
            relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
            relative_coords_table = torch.stack(
                torch.meshgrid([relative_coords_h,
                                relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
            if pretrained_window_size[0] > 0:
                relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
                relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
            else:
                relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
                relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
            relative_coords_table *= 8  # normalize to -8, 8
            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
                torch.abs(relative_coords_table) + 1.0) / np.log2(8)
            self.register_buffer("relative_coords_table", relative_coords_table)

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(self.window_size[0])
            coords_w = torch.arange(self.window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)

            self.qkv = nn.Linear(dim, dim * 3, bias=False)
            if qkv_bias:
                self.q_bias = nn.Parameter(torch.zeros(dim))
                self.v_bias = nn.Parameter(torch.zeros(dim))
            else:
                self.q_bias = None
                self.v_bias = None
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)
            self.softmax = nn.Softmax(dim=-1)

        def forward(self, x, mask=None):
            B_, N, C = x.shape
            qkv_bias = None
            if self.q_bias is not None:
                # k carries no bias in Swin v2, so pad zeros between q_bias and v_bias
                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
            qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

            # cosine attention
            attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
            max_tensor = torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)
            logit_scale = torch.clamp(self.logit_scale, max=max_tensor).exp()
            attn = attn * logit_scale

            relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
            relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
            attn = attn + relative_position_bias.unsqueeze(0)

            if mask is not None:
                nW = mask.shape[0]
                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
                attn = attn.view(-1, self.num_heads, N, N)
                attn = self.softmax(attn)
            else:
                attn = self.softmax(attn)

            attn = self.attn_drop(attn)

            try:
                x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
            except Exception:
                # dtype mismatch under mixed precision: cast the attention map to half before the matmul
                x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)

            x = self.proj(x)
            x = self.proj_drop(x)
            return x

        def extra_repr(self) -> str:
            return f'dim={self.dim}, window_size={self.window_size}, ' \
                   f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'

        def flops(self, N):
            # calculate flops for 1 window with token length of N
            flops = 0
            # qkv = self.qkv(x)
            flops += N * self.dim * 3 * self.dim
            # attn = (q @ k.transpose(-2, -1))
            flops += self.num_heads * N * (self.dim // self.num_heads) * N
            # x = (attn @ v)
            flops += self.num_heads * N * N * (self.dim // self.num_heads)
            # x = self.proj(x)
            flops += N * self.dim * self.dim
            return flops


    class Mlp_v2(nn.Module):
        # Feed-forward network used inside each Swin block (SiLU activation by default).
        def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
            super().__init__()
            out_features = out_features or in_features
            hidden_features = hidden_features or in_features
            self.fc1 = nn.Linear(in_features, hidden_features)
            self.act = act_layer()
            self.fc2 = nn.Linear(hidden_features, out_features)
            self.drop = nn.Dropout(drop)

        def forward(self, x):
            x = self.fc1(x)
            x = self.act(x)
            x = self.drop(x)
            x = self.fc2(x)
            x = self.drop(x)
            return x


    class SwinTransformerLayer_v2(nn.Module):
        # One Swin v2 layer: (shifted) window attention + MLP, operating on a [B, C, H, W] feature map.
        def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                     mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                     act_layer=nn.SiLU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
            super().__init__()
            self.dim = dim
            # self.input_resolution = input_resolution
            self.num_heads = num_heads
            self.window_size = window_size
            self.shift_size = shift_size
            self.mlp_ratio = mlp_ratio
            # if min(self.input_resolution) <= self.window_size:
            #     # if window size is larger than input resolution, we don't partition windows
            #     self.shift_size = 0
            #     self.window_size = min(self.input_resolution)
            assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

            self.norm1 = norm_layer(dim)
            self.attn = WindowAttention_v2(
                dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
                qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
                pretrained_window_size=(pretrained_window_size, pretrained_window_size))
            self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp_v2(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        def create_mask(self, H, W):
            # calculate attention mask for SW-MSA
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            def window_partition(x, window_size):
                """
                Args:
                    x: (B, H, W, C)
                    window_size (int): window size
                Returns:
                    windows: (num_windows*B, window_size, window_size, C)
                """
                B, H, W, C = x.shape
                x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
                windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
                return windows

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
            return attn_mask

        def forward(self, x):
            # reshape x[b c h w] to x[b l c]
            _, _, H_, W_ = x.shape
            Padding = False
            if min(H_, W_) < self.window_size or H_ % self.window_size != 0 or W_ % self.window_size != 0:
                Padding = True
                # pad the feature map so H and W become divisible by window_size
                pad_r = (self.window_size - W_ % self.window_size) % self.window_size
                pad_b = (self.window_size - H_ % self.window_size) % self.window_size
                x = F.pad(x, (0, pad_r, 0, pad_b))

            B, C, H, W = x.shape
            L = H * W
            x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)  # b, L, c

            # create mask from init to forward
            if self.shift_size > 0:
                attn_mask = self.create_mask(H, W).to(x.device)
            else:
                attn_mask = None

            shortcut = x
            x = x.view(B, H, W, C)

            # cyclic shift
            if self.shift_size > 0:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            else:
                shifted_x = x

            # partition windows
            def window_partition(x, window_size):
                """
                Args:
                    x: (B, H, W, C)
                    window_size (int): window size
                Returns:
                    windows: (num_windows*B, window_size, window_size, C)
                """
                B, H, W, C = x.shape
                x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
                windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
                return windows

            def window_reverse(windows, window_size, H, W):
                """
                Args:
                    windows: (num_windows*B, window_size, window_size, C)
                    window_size (int): Window size
                    H (int): Height of image
                    W (int): Width of image
                Returns:
                    x: (B, H, W, C)
                """
                B = int(windows.shape[0] / (H * W / window_size / window_size))
                x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
                x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
                return x

            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

            # W-MSA/SW-MSA
            attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

            # merge windows
            attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

            # reverse cyclic shift
            if self.shift_size > 0:
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = shifted_x
            x = x.view(B, H * W, C)
            x = shortcut + self.drop_path(self.norm1(x))

            # FFN
            x = x + self.drop_path(self.norm2(self.mlp(x)))
            x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W)  # b c h w

            if Padding:
                x = x[:, :, :H_, :W_]  # reverse padding
            return x

        def extra_repr(self) -> str:
            return f"dim={self.dim}, num_heads={self.num_heads}, " \
                   f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

        def flops(self):
            # note: relies on self.input_resolution, which is not set in this detection variant
            flops = 0
            H, W = self.input_resolution
            # norm1
            flops += self.dim * H * W
            # W-MSA/SW-MSA
            nW = H * W / self.window_size / self.window_size
            flops += nW * self.attn.flops(self.window_size * self.window_size)
            # mlp
            flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
            # norm2
            flops += self.dim * H * W
            return flops


    class SwinTransformer2Block(nn.Module):
        # A stack of SwinTransformerLayer_v2 layers, alternating plain and shifted windows.
        def __init__(self, c1, c2, num_heads, num_layers, window_size=7):
            super().__init__()
            self.conv = None
            if c1 != c2:
                self.conv = Conv(c1, c2)

            # remove input_resolution
            self.blocks = nn.Sequential(*[SwinTransformerLayer_v2(dim=c2, num_heads=num_heads, window_size=window_size,
                                                                  shift_size=0 if (i % 2 == 0) else window_size // 2)
                                          for i in range(num_layers)])

        def forward(self, x):
            if self.conv is not None:
                x = self.conv(x)
            x = self.blocks(x)
            return x


    class SwinV2_CSPB(nn.Module):
        # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
        def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
            super(SwinV2_CSPB, self).__init__()
            c_ = int(c2)  # hidden channels
            self.cv1 = Conv(c1, c_, 1, 1)
            self.cv2 = Conv(c_, c_, 1, 1)
            self.cv3 = Conv(2 * c_, c2, 1, 1)
            num_heads = c_ // 32
            self.m = SwinTransformer2Block(c_, c_, num_heads, n)
            # self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

        def forward(self, x):
            x1 = self.cv1(x)
            y1 = self.m(x1)
            y2 = self.cv2(x1)
            return self.cv3(torch.cat((y1, y2), dim=1))
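
    Before wiring the block into the YAML machinery, a quick standalone shape check can save time. A minimal sketch, assuming the code above has been pasted into block.py and the modified ultralytics package is the one Python picks up; the tensor size is just an illustrative P4/16 feature map:

    import torch
    from ultralytics.nn.modules.block import SwinV2_CSPB  # importable once the code above is in block.py

    block = SwinV2_CSPB(128, 128, n=1)   # 128 channels -> 4 attention heads (c_ // 32)
    x = torch.randn(1, 128, 40, 40)      # roughly a P4/16 map for a 640x640 input at scale 'n'
    with torch.no_grad():
        y = block(x)
    print(y.shape)                       # expected: torch.Size([1, 128, 40, 40])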

    Next, export the new module from the __init__.py in the same directory (ultralytics/nn/modules/__init__.py). GAM_Attention, ResBlock_CBAM, GCT and C3STR in the lists below come from other custom modifications and are not stock Ultralytics modules; if you only need Swin v2, adding SwinV2_CSPB is enough:

    from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck,
                        HGBlock, HGStem, Proto, RepC3, GAM_Attention, ResBlock_CBAM, GCT, C3STR, SwinV2_CSPB)
    from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus,
                       GhostConv, LightConv, RepConv, SpatialAttention)
    from .head import Classify, Detect, Pose, RTDETRDecoder, Segment
    from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d,
                              MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer)

    __all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus',
               'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer',
               'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3',
               'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect',
               'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI',
               'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP',
               'ResBlock_CBAM', 'CBAM', 'GAM_Attention', 'GCT', 'C3STR', 'SwinV2_CSPB')
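
    To confirm the export worked before touching tasks.py, a quick sanity check (run it where the modified ultralytics package is the one Python imports, e.g. from the repo root):

    # should print the class without raising ImportError
    from ultralytics.nn.modules import SwinV2_CSPB
    print(SwinV2_CSPB)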

    Then open ultralytics/nn/tasks.py and add the new module to the imports at the top of the file:

    from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
                                        Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
                                        Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
                                        RTDETRDecoder, Segment, CBAM, GAM_Attention, ResBlock_CBAM, GCT, C3STR,
                                        SwinV2_CSPB)

    Still in tasks.py, search for "if m in" inside parse_model and add the new Swin module to that tuple:

    if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
             BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3,
             CBAM, GAM_Attention, ResBlock_CBAM, GCT, C3STR, SwinV2_CSPB):
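
    For context, membership in this tuple means parse_model fills in the module's input channels from the previous layer and treats the first YAML argument as the output channel count, scaled by the width factor of the chosen scale (see the channel sketch after the YAML above). A few lines further down in stock Ultralytics there is a second, smaller tuple (BottleneckCSP, C1, C2, C2f, C3, ...) whose members receive the YAML repeats count as a constructor argument via args.insert(2, n); whether SwinV2_CSPB should also be added there depends on how you want the repeats column interpreted and may vary with your Ultralytics version, so treat that as an optional variation rather than part of this walkthrough.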

    Finally, from the repository root:

    Command line: python setup.py install  (re-installs the modified package so the new module is registered; an editable install with pip install -e . works as well)

    Command line: your training command
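
    For reference, a minimal training sketch using the Python API; the YAML filename and dataset path are placeholders for your own files (the 'n' in the filename is what selects the scale row):

    from ultralytics import YOLO

    # hypothetical filename for the modified YAML above; the task is inferred from the Segment head
    model = YOLO("yolov8n-swinv2-seg.yaml")
    model.train(data="your_dataset.yaml",   # your dataset config
                imgsz=640, epochs=100, batch=8, device=0)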

  • Original article: https://blog.csdn.net/xty123abc/article/details/133428517