• YOLOv9独家改进|动态蛇形卷积Dynamic Snake Convolution与RepNCSPELAN4融合



    专栏介绍:YOLOv9改进系列 | 包含深度学习最新创新,助力高效涨点!!!


    一、改进点介绍

            Dynamic Snake Convolution是一种针对细长微弱的局部结构特征与复杂多变的全局形态特征设计的卷积模块。

            RepNCSPELAN4是YOLOv9中的特征提取模块,作用类似于YOLOv5中的C3模块与YOLOv8中的C2f模块。


    二、RepNCSPELAN4Dynamic模块详解

     2.1 模块简介

           RepNCSPELAN4Dynamic的主要思想: 将Dynamic Snake Convolution融合到RepNCSPELAN4中(代码中对应的类名为RepNCSPELAN4DySnakeConv)。


    三、 RepNCSPELAN4Dynamic模块使用教程

    3.1 RepNCSPELAN4Dynamic模块的代码

    class RepNBottleneck_DySnakeConv(RepNBottleneck):
        """Bottleneck used inside RepNCSP_DySnakeConv.

        After the parent constructor runs, this rebuilds ``cv1`` as a
        ``RepConvN`` and ``cv2`` as a ``Conv``, and sets ``self.add`` to
        ``shortcut and c1 == c2``. No DySnakeConv layer is created here.
        NOTE(review): these look like the same layers RepNBottleneck already
        builds; confirm against models/common.py before relying on any
        difference.

        Args:
            c1: input channels.
            c2: output channels.
            shortcut: request a residual connection.
            g: groups for ``cv2``.
            k: kernel sizes; ``k[0]`` for ``cv1``, ``k[1]`` for ``cv2``.
            e: hidden-channel expansion ratio.
        """

        def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
            super().__init__(c1, c2, shortcut, g, k, e)
            hidden = int(c2 * e)
            kernel_in, kernel_out = k[0], k[1]
            self.cv1 = RepConvN(c1, hidden, kernel_in, 1)
            self.cv2 = Conv(hidden, c2, kernel_out, s=1, g=g)
            # A residual is only possible when input and output widths match.
            self.add = shortcut and c1 == c2
    class RepNCSP_DySnakeConv(RepNCSP):
        """RepNCSP variant whose three projections use Dynamic Snake Convolution.

        ``DySnakeConv(c_in, c_out)`` concatenates three branches on the channel
        axis and so emits ``3 * c_out`` channels. Each DySnakeConv is followed
        by a 1x1 ``Conv`` that folds the result back to ``c_out``. This keeps
        ``cv1``/``cv2`` at ``c_`` channels (what ``self.m`` and ``cv3`` are
        built for) and ``cv3`` at ``c2``.

        Args:
            c1: input channels.
            c2: output channels.
            n: number of bottlenecks in ``self.m``.
            shortcut: residual flag forwarded to each bottleneck.
            g: groups forwarded to each bottleneck.
            e: hidden-channel expansion ratio.
        """

        def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
            # The parent constructor is still needed (nn.Module setup and the
            # inherited forward); layers it registers under these names are replaced below.
            super().__init__(c1, c2, n, shortcut, g, e)
            c_ = int(c2 * e)  # hidden channels
            self.cv1 = self._snake_conv(c1, c_)
            self.cv2 = self._snake_conv(c1, c_)
            self.cv3 = self._snake_conv(2 * c_, c2)  # optional act=FReLU(c2)
            self.m = nn.Sequential(*(RepNBottleneck_DySnakeConv(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

        @staticmethod
        def _snake_conv(c_in, c_out, k=3):
            """Return DySnakeConv(c_in, c_out, k) followed by a 1x1 Conv mapping 3 * c_out -> c_out."""
            return nn.Sequential(DySnakeConv(c_in, c_out, k), Conv(3 * c_out, c_out, 1))


    class RepNCSPELAN4DySnakeConv(RepNCSPELAN4):
        """CSP-ELAN block whose two inner branches use Dynamic Snake Convolution.

        Args:
            c1: input channels.
            c2: output channels.
            c3: channels produced by ``cv1``; half of them enter ``cv2``.
            c4: channels produced by each of ``cv2`` and ``cv3``.
            c5: number of bottlenecks inside each RepNCSP_DySnakeConv.
        """

        def __init__(self, c1, c2, c3, c4, c5=1):
            super().__init__(c1, c2, c3, c4, c5)
            self.cv1 = Conv(c1, c3, k=1, s=1)
            # DySnakeConv(c4, c4) emits 3 * c4 channels; the trailing 1x1 Conv
            # restores c4 so that cv3 (input c4) and cv4 (input c3 + 2 * c4) line up.
            self.cv2 = nn.Sequential(
                RepNCSP_DySnakeConv(c3 // 2, c4, c5),
                DySnakeConv(c4, c4, 3),
                Conv(3 * c4, c4, 1),
            )
            self.cv3 = nn.Sequential(
                RepNCSP_DySnakeConv(c4, c4, c5),
                DySnakeConv(c4, c4, 3),
                Conv(3 * c4, c4, 1),
            )
            self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)
    class DySnakeConv(nn.Module):
        """Dynamic Snake Convolution block made of three parallel branches.

        The input goes through a regular ``Conv`` and two ``DSConv`` layers:
        morph 0 is the snake kernel along the x-axis, morph 1 along the
        y-axis. The three results are concatenated on the channel axis.
        The output therefore carries ``3 * ouc`` channels, assuming
        ``Conv(inc, ouc, k)`` yields ``ouc``.

        Args:
            inc: input channels.
            ouc: output channels of each branch.
            k: kernel size shared by all branches.
        """

        def __init__(self, inc, ouc, k=3) -> None:
            super().__init__()
            self.conv_0 = Conv(inc, ouc, k)
            self.conv_x = DSConv(inc, ouc, morph=0, kernel_size=k)
            self.conv_y = DSConv(inc, ouc, morph=1, kernel_size=k)

        def forward(self, x):
            branch_outputs = [layer(x) for layer in (self.conv_0, self.conv_x, self.conv_y)]
            return torch.cat(branch_outputs, dim=1)
    class DSConv(nn.Module):
        """Dynamic Snake Convolution along a single axis.

        Predicts per-pixel offsets (``offset_conv`` -> ``bn`` -> ``tanh``) and
        samples the input along a snake-shaped kernel with ``DSC``. It then
        applies a strided 1-D convolution, GroupNorm and ``Conv.default_act``.
        """

        def __init__(self, in_ch, out_ch, morph, kernel_size=3, if_offset=True, extend_scope=1):
            """
            The Dynamic Snake Convolution
            :param in_ch: input channel
            :param out_ch: output channel
            :param morph: kernel orientation, along the x-axis (0) or the y-axis (1)
            :param kernel_size: the size of kernel (number of snake points)
            :param if_offset: whether deformation is required; if False the kernel stays straight
            :param extend_scope: the range to expand (default 1 for this method)
            """
            super().__init__()
            # 2*K offset channels: the first K are y offsets, the last K x offsets (split in DSC).
            self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1)
            self.bn = nn.BatchNorm2d(2 * kernel_size)
            self.kernel_size = kernel_size
            # Two 1-D convolutions; forward uses only the one matching ``morph``.
            # Stride == kernel size, so every K-point group of the deformed map
            # collapses back to a single output pixel.
            # NOTE(review): the unused conv never receives gradients, which may
            # matter for DDP (find_unused_parameters).
            self.dsc_conv_x = nn.Conv2d(
                in_ch,
                out_ch,
                kernel_size=(kernel_size, 1),
                stride=(kernel_size, 1),
                padding=0,
            )
            self.dsc_conv_y = nn.Conv2d(
                in_ch,
                out_ch,
                kernel_size=(1, kernel_size),
                stride=(1, kernel_size),
                padding=0,
            )
            self.gn = nn.GroupNorm(self._gn_groups(out_ch), out_ch)
            self.act = Conv.default_act
            self.extend_scope = extend_scope
            self.morph = morph
            self.if_offset = if_offset

        @staticmethod
        def _gn_groups(channels):
            """Return a valid GroupNorm group count for ``channels``.

            Uses ``channels // 4`` whenever that divides ``channels``, which is
            the same count as before. Otherwise it falls back to the largest
            divisor of ``channels`` not exceeding it (at least 1), because
            nn.GroupNorm requires ``channels`` to be divisible by the group count.
            """
            groups = max(1, channels // 4)
            while channels % groups:
                groups -= 1
            return groups

        def forward(self, f):
            # tanh keeps offsets in (-1, 1) to mimic the snake's swing.
            offset = torch.tanh(self.bn(self.offset_conv(f)))
            dsc = DSC(f.shape, self.kernel_size, self.extend_scope, self.morph)
            deformed_feature = dsc.deform_conv(f, offset, self.if_offset)
            conv = self.dsc_conv_x if self.morph == 0 else self.dsc_conv_y
            x = conv(deformed_feature.type(f.dtype))
            return self.act(self.gn(x))
    # Core code of Dynamic Snake Convolution; tensor shapes are noted next to the code.
    class DSC(object):
        """Snake-kernel coordinate generation and bilinear sampling for one call.

        Naming follows the original implementation: ``width`` is
        ``input_shape[2]`` and ``height`` is ``input_shape[3]``. For the usual
        NCHW layout, ``width`` therefore counts rows and ``height`` counts
        columns. Shapes below are written ``[B, C, W, H]`` in that sense.

        Args:
            input_shape: ``(B, C, W, H)`` shape of the feature map to sample.
            kernel_size: number of points ``K`` on the snake kernel.
            extend_scope: scale applied to the accumulated offsets
                (DSConv passes tanh-bounded offsets, i.e. within -1..1).
            morph: 0 -> kernel runs along the last axis (x);
                1 -> kernel runs along axis 2 (y).
        """

        def __init__(self, input_shape, kernel_size, extend_scope, morph):
            self.num_points = kernel_size
            self.width = input_shape[2]
            self.height = input_shape[3]
            self.morph = morph
            self.extend_scope = extend_scope  # offset (-1 ~ 1) * extend_scope
            self.num_batch = input_shape[0]
            self.num_channels = input_shape[1]

        def _cumulative_offsets(self, offset):
            """Accumulate per-point offsets outward from the kernel centre.

            The centre point is pinned to 0. Every other point adds its own
            offset to that of its inner neighbour, which produces the
            continuous "snake" swing.

            Args:
                offset: ``[B, K, W, H]`` raw offsets for one axis.

            Returns:
                ``[B, K, W, H]`` accumulated offsets on ``offset.device``.
            """
            per_point = offset.permute(1, 0, 2, 3)  # [K, B, W, H]
            acc = per_point.detach().clone()
            center = int(self.num_points // 2)
            acc[center] = 0
            # Walk outward on both sides up to and including the outermost
            # points (indices K - 1 and 0). Every write mixes in the
            # non-detached ``per_point``, so gradients reach offset_conv; the
            # previous ``range(1, center)`` skipped all writes for K = 3.
            for index in range(1, self.num_points - center):
                acc[center + index] = acc[center + index - 1] + per_point[center + index]
            for index in range(1, center + 1):
                acc[center - index] = acc[center - index + 1] + per_point[center - index]
            return acc.permute(1, 0, 2, 3).to(offset.device)

        def _coordinate_map_3D(self, offset, if_offset):
            """Compute the sampling coordinates of every snake-kernel point.

            Args:
                offset: ``[B, 2*K, W, H]``; channels ``[:K]`` are y offsets and
                    ``[K:]`` are x offsets.
                if_offset: if False the offsets are ignored and the kernel stays straight.

            Returns:
                ``(y, x)`` coordinate maps, each ``[B, K*W, H]`` for morph 0
                or ``[B, W, K*H]`` for morph 1.
            """
            device = offset.device
            y_offset, x_offset = torch.split(offset, self.num_points, dim=1)

            # Pixel centres, each [1, K, W, H]: y_center[..., w, h] = w, x_center[..., w, h] = h.
            y_center = torch.arange(0, self.width).repeat([self.height])
            y_center = y_center.reshape(self.height, self.width).permute(1, 0)
            y_center = y_center.reshape([-1, self.width, self.height])
            y_center = y_center.repeat([self.num_points, 1, 1]).float().unsqueeze(0)

            x_center = torch.arange(0, self.height).repeat([self.width])
            x_center = x_center.reshape([-1, self.width, self.height])
            x_center = x_center.repeat([self.num_points, 1, 1]).float().unsqueeze(0)

            # Straight kernel: K evenly spaced points -K//2 .. K//2 along the
            # kernel axis, 0 along the other axis.
            span = torch.linspace(
                -int(self.num_points // 2),
                int(self.num_points // 2),
                int(self.num_points),
            )
            flat = torch.zeros_like(span)
            y_line, x_line = (flat, span) if self.morph == 0 else (span, flat)

            # [B, K, W, H] absolute coordinates of the undeformed kernel.
            y_new = (y_center + y_line.reshape(1, -1, 1, 1)).repeat(self.num_batch, 1, 1, 1).to(device)
            x_new = (x_center + x_line.reshape(1, -1, 1, 1)).repeat(self.num_batch, 1, 1, 1).to(device)

            if if_offset:
                # Only the axis perpendicular to the kernel swings: y for morph 0, x for morph 1.
                if self.morph == 0:
                    y_new = y_new.add(self._cumulative_offsets(y_offset).mul(self.extend_scope))
                else:
                    x_new = x_new.add(self._cumulative_offsets(x_offset).mul(self.extend_scope))

            if self.morph == 0:
                split_shape = [self.num_batch, self.num_points, 1, self.width, self.height]
                out_shape = [self.num_batch, self.num_points * self.width, 1 * self.height]
            else:
                split_shape = [self.num_batch, 1, self.num_points, self.width, self.height]
                out_shape = [self.num_batch, 1 * self.width, self.num_points * self.height]
            y_new = y_new.reshape(split_shape).permute(0, 3, 1, 4, 2).reshape(out_shape)
            x_new = x_new.reshape(split_shape).permute(0, 3, 1, 4, 2).reshape(out_shape)
            return y_new, x_new

        def _bilinear_interpolate_3D(self, input_feature, y, x):
            """Bilinearly sample ``input_feature`` at the coordinates ``(y, x)``.

            Args:
                input_feature: ``[B, C, W, H]`` feature map.
                y: y coordinate map from ``_coordinate_map_3D``.
                x: x coordinate map from ``_coordinate_map_3D``.

            Returns:
                Deformed feature map ``[B, C, K*W, H]`` (morph 0) or
                ``[B, C, W, K*H]`` (morph 1).
            """
            device = input_feature.device
            y = y.reshape([-1]).float()
            x = x.reshape([-1]).float()
            max_y = self.width - 1
            max_x = self.height - 1

            # Integer neighbours of each sample point.
            y0 = torch.floor(y).int()
            y1 = y0 + 1
            x0 = torch.floor(x).int()
            x1 = x0 + 1

            # Gather indices are clipped into the feature map.
            y0_idx = torch.clamp(y0, 0, max_y)
            y1_idx = torch.clamp(y1, 0, max_y)
            x0_idx = torch.clamp(x0, 0, max_x)
            x1_idx = torch.clamp(x1, 0, max_x)

            # [B*W*H, C]: one row per pixel, batch-major.
            input_feature_flat = input_feature.reshape(
                self.num_batch, self.num_channels, self.width, self.height
            ).permute(0, 2, 3, 1).reshape(-1, self.num_channels)

            # Row offset of each sample's own image inside input_feature_flat
            # (integer arithmetic, so it stays exact on large maps).
            points_per_image = self.num_points * self.width * self.height
            base = torch.arange(self.num_batch, device=device) * (self.height * self.width)
            base = base.repeat_interleave(points_per_image)
            base_y0 = base + y0_idx * self.height
            base_y1 = base + y1_idx * self.height

            # The batch offset must stay in the index; dropping it makes every
            # sample in the batch read from image 0.
            value_a0 = input_feature_flat[(base_y0 + x0_idx).long()]
            value_c0 = input_feature_flat[(base_y0 + x1_idx).long()]
            value_a1 = input_feature_flat[(base_y1 + x0_idx).long()]
            value_c1 = input_feature_flat[(base_y1 + x1_idx).long()]

            # Interpolation weights use a looser clamp (up to max + 1) than the
            # gather indices. NOTE(review): near and beyond the border this does
            # not reduce to plain bilinear interpolation; review before relying on it.
            y0_w = torch.clamp(y0, 0, max_y + 1).float()
            y1_w = torch.clamp(y1, 0, max_y + 1).float()
            x0_w = torch.clamp(x0, 0, max_x + 1).float()
            x1_w = torch.clamp(x1, 0, max_x + 1).float()

            vol_a0 = ((y1_w - y) * (x1_w - x)).unsqueeze(-1)
            vol_c0 = ((y1_w - y) * (x - x0_w)).unsqueeze(-1)
            vol_a1 = ((y - y0_w) * (x1_w - x)).unsqueeze(-1)
            vol_c1 = ((y - y0_w) * (x - x0_w)).unsqueeze(-1)

            outputs = value_a0 * vol_a0 + value_c0 * vol_c0 + value_a1 * vol_a1 + value_c1 * vol_c1

            if self.morph == 0:
                out_shape = [self.num_batch, self.num_points * self.width, 1 * self.height, self.num_channels]
            else:
                out_shape = [self.num_batch, 1 * self.width, self.num_points * self.height, self.num_channels]
            return outputs.reshape(out_shape).permute(0, 3, 1, 2)

        def deform_conv(self, input, offset, if_offset):
            """Sample ``input`` along the snake kernel described by ``offset``."""
            y, x = self._coordinate_map_3D(offset, if_offset)
            deformed_feature = self._bilinear_interpolate_3D(input, y, x)
            return deformed_feature

    3.2 在YOLOv9中的添加教程

    阅读YOLOv9添加模块教程或使用下文操作

            1. 将上述模块代码添加到YOLOv9工程models目录下common.py文件的最末尾(必须放在末尾,否则父类RepNCSP、RepNCSPELAN4等尚未定义,会导致类继承报错)。

             2. 在YOLOv9工程models目录下yolo.py文件中(约第681行,具体行号可能因版本变化而不同),找到parse_model中包含RepNCSPELAN4、SPPELAN的模块集合,在其中加入RepNCSPELAN4DySnakeConv,修改后该行末尾如下:

                RepNCSPELAN4, SPPELAN, RepNCSPELAN4DySnakeConv}:

    3.3 运行配置文件

    # YOLOv9 variant: backbone layer 3 uses RepNCSPELAN4DySnakeConv
    # Powered by https://blog.csdn.net/StopAndGoyyy

    # parameters
    nc: 80  # number of classes
    #depth_multiple: 0.33  # model depth multiple
    depth_multiple: 1  # model depth multiple
    #width_multiple: 0.25  # layer channel multiple
    width_multiple: 1  # layer channel multiple
    #activation: nn.LeakyReLU(0.1)
    #activation: nn.ReLU()

    # anchors
    anchors: 3

    # YOLOv9 backbone
    backbone:
      [
       [-1, 1, Silence, []],

       # conv down
       [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2

       # conv down
       [-1, 1, Conv, [128, 3, 2]],  # 2-P2/4

       # elan-1 block (Dynamic Snake Convolution variant)
       [-1, 1, RepNCSPELAN4DySnakeConv, [256, 128, 64, 1]],  # 3

       # avg-conv down
       [-1, 1, ADown, [256]],  # 4-P3/8

       # elan-2 block
       [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 5

       # avg-conv down
       [-1, 1, ADown, [512]],  # 6-P4/16

       # elan-2 block
       [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 7

       # avg-conv down
       [-1, 1, ADown, [512]],  # 8-P5/32

       # elan-2 block
       [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 9
      ]

    # YOLOv9 head
    head:
      [
       # elan-spp block
       [-1, 1, SPPELAN, [512, 256]],  # 10

       # up-concat merge
       [-1, 1, nn.Upsample, [None, 2, 'nearest']],
       [[-1, 7], 1, Concat, [1]],  # cat backbone P4

       # elan-2 block
       [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 13

       # up-concat merge
       [-1, 1, nn.Upsample, [None, 2, 'nearest']],
       [[-1, 5], 1, Concat, [1]],  # cat backbone P3

       # elan-2 block
       [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 16 (P3/8-small)

       # avg-conv-down merge
       [-1, 1, ADown, [256]],
       [[-1, 13], 1, Concat, [1]],  # cat head P4

       # elan-2 block
       [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 19 (P4/16-medium)

       # avg-conv-down merge
       [-1, 1, ADown, [512]],
       [[-1, 10], 1, Concat, [1]],  # cat head P5

       # elan-2 block
       [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 22 (P5/32-large)

       # multi-level reversible auxiliary branch

       # routing
       [5, 1, CBLinear, [[256]]],  # 23
       [7, 1, CBLinear, [[256, 512]]],  # 24
       [9, 1, CBLinear, [[256, 512, 512]]],  # 25

       # conv down
       [0, 1, Conv, [64, 3, 2]],  # 26-P1/2

       # conv down
       [-1, 1, Conv, [128, 3, 2]],  # 27-P2/4

       # elan-1 block
       [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 28

       # avg-conv down fuse
       [-1, 1, ADown, [256]],  # 29-P3/8
       [[23, 24, 25, -1], 1, CBFuse, [[0, 0, 0]]],  # 30

       # elan-2 block
       [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 31

       # avg-conv down fuse
       [-1, 1, ADown, [512]],  # 32-P4/16
       [[24, 25, -1], 1, CBFuse, [[1, 1]]],  # 33

       # elan-2 block
       [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 34

       # avg-conv down fuse
       [-1, 1, ADown, [512]],  # 35-P5/32
       [[25, -1], 1, CBFuse, [[2]]],  # 36

       # elan-2 block
       [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 37

       # detection head

       # detect
       [[31, 34, 37, 16, 19, 22], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)
      ]

    3.4 训练过程


    欢迎关注!


  • 相关阅读:
    Python中取2023, 9, 1——2023, 10, 31的全部时间
    Stable Diffusion webui 常用启动参数
    微信小程序签名
    前端模糊搜索
    什么是hive的静态分区和动态分区,它们又有什么区别呢?hive动态分区详解
    【100天算法】-面试题 17.17.多次搜索(day 43)
    【轻量化网络】MobileNet系列
    【web开发】6、Django(1)
    万字长文:从计算机本源深入探寻volatile和Java内存模型
    3GPP R17连接态省电特性
  • 原文地址:https://blog.csdn.net/StopAndGoyyy/article/details/136426572