• MicroNet (ICCV 2021)

  The code below collects MicroNet's activation utilities (_make_divisible, the hard sigmoid/tanh, an SE layer, and the DYShiftMax dynamic activation), followed by a small shape test at the end.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    def _make_divisible(v, divisor, min_value=None):
        """
        This function is taken from the original tf repo.
        It ensures that all layers have a channel number that is divisible by 8.
        It can be seen here:
        https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
        :param v:
        :param divisor:
        :param min_value:
        :return:
        """
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        # Make sure that round down does not go down by more than 10%.
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v
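    # Worked examples: values are rounded to the nearest multiple of the
    # divisor and never shrink by more than 10%.
    assert _make_divisible(30, 8) == 32
    assert _make_divisible(20, 8) == 24
    assert _make_divisible(7, 8) == 8   # clamped up to min_value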
    ########################################################################
    # sigmoid and tanh
    ########################################################################
    # h_sigmoid: x in [-3, 3] is mapped linearly to y in [0, h_max]
    class h_sigmoid(nn.Module):
        def __init__(self, inplace=True, h_max=1):
            super(h_sigmoid, self).__init__()
            self.relu = nn.ReLU6(inplace=inplace)
            self.h_max = h_max / 6

        def forward(self, x):
            return self.relu(x + 3) * self.h_max


    # h_tanh: x in [-3, 3] is mapped linearly to y in [-h_max, h_max]
    class h_tanh(nn.Module):
        def __init__(self, inplace=True, h_max=1):
            super(h_tanh, self).__init__()
            self.relu = nn.ReLU6(inplace=inplace)
            self.h_max = h_max

        def forward(self, x):
            return self.relu(x + 3) * self.h_max / 3 - self.h_max
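    # Sanity check of the clamp points (h_max = 1): h_sigmoid maps -3/0/3
    # to 0/0.5/1 and h_tanh maps -3/0/3 to -1/0/1.
    _t = torch.tensor([-3.0, 0.0, 3.0])
    assert torch.allclose(h_sigmoid()(_t), torch.tensor([0.0, 0.5, 1.0]))
    assert torch.allclose(h_tanh()(_t), torch.tensor([-1.0, 0.0, 1.0]))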
    ########################################################################
    # wrap functions
    ########################################################################
    def get_act_layer(inp, oup, mode='SE1', act_relu=True, act_max=2, act_bias=True,
                      init_a=[1.0, 0.0], reduction=4, init_b=[0.0, 0.0], g=None,
                      act='relu', expansion=True):
        layer = None
        if mode == 'SE1':
            layer = nn.Sequential(
                SELayer(inp, oup, reduction=reduction),
                nn.ReLU6(inplace=True) if act_relu else nn.Sequential()
            )
        elif mode == 'SE0':
            layer = nn.Sequential(
                SELayer(inp, oup, reduction=reduction),
            )
        elif mode == 'NA':
            layer = nn.ReLU6(inplace=True) if act_relu else nn.Sequential()
        elif mode == 'LeakyReLU':
            layer = nn.LeakyReLU(inplace=True) if act_relu else nn.Sequential()
        elif mode == 'RReLU':
            layer = nn.RReLU(inplace=True) if act_relu else nn.Sequential()
        elif mode == 'PReLU':
            layer = nn.PReLU() if act_relu else nn.Sequential()
        elif mode == 'DYShiftMax':
            layer = DYShiftMax(inp, oup, act_max=act_max, act_relu=act_relu,
                               init_a=init_a, reduction=reduction, init_b=init_b,
                               g=g, expansion=expansion)
        return layer
    ########################################################################
    # dynamic activation layers (SE, DYShiftMax, etc)
    ########################################################################
    class SELayer(nn.Module):
        def __init__(self, inp, oup, reduction=4):
            super(SELayer, self).__init__()
            self.oup = oup
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            # determine squeeze
            squeeze = get_squeeze_channels(inp, reduction)
            print('reduction: {}, squeeze: {}/{}'.format(reduction, inp, squeeze))
            self.fc = nn.Sequential(
                nn.Linear(inp, squeeze),
                nn.ReLU(inplace=True),
                nn.Linear(squeeze, oup),
                h_sigmoid()
            )

        def forward(self, x):
            if isinstance(x, list):
                x_in = x[0]
                x_out = x[1]
            else:
                x_in = x
                x_out = x
            b, c, _, _ = x_in.size()
            y = self.avg_pool(x_in).view(b, c)
            y = self.fc(y).view(b, self.oup, 1, 1)
            return x_out * y
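    # Usage sketch (runnable once get_squeeze_channels below is defined):
    #   se = get_act_layer(8, 8, mode='SE1', reduction=4)
    #   se(torch.rand(2, 8, 4, 4)).shape   # -> torch.Size([2, 8, 4, 4])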
    class DYShiftMax(nn.Module):
        def __init__(self, inp, oup, reduction=4, act_max=1.0, act_relu=True,
                     init_a=[0.0, 0.0], init_b=[0.0, 0.0], relu_before_pool=False,
                     g=None, expansion=False):
            super(DYShiftMax, self).__init__()
            self.oup = oup
            self.act_max = act_max * 2
            self.act_relu = act_relu
            self.avg_pool = nn.Sequential(
                nn.ReLU(inplace=True) if relu_before_pool else nn.Sequential(),
                nn.AdaptiveAvgPool2d(1)
            )
            # four coefficients (a1, b1, a2, b2) when taking the max of two
            # branches, two (a1, b1) when act_relu is False
            self.exp = 4 if act_relu else 2
            self.init_a = init_a
            self.init_b = init_b
            # determine squeeze
            squeeze = _make_divisible(inp // reduction, 4)
            if squeeze < 4:
                squeeze = 4
            print('reduction: {}, squeeze: {}/{}'.format(reduction, inp, squeeze))
            print('init-a: {}, init-b: {}'.format(init_a, init_b))
            self.fc = nn.Sequential(
                nn.Linear(inp, squeeze),
                nn.ReLU(inplace=True),
                nn.Linear(squeeze, oup * self.exp),
                h_sigmoid()
            )
            # g is a pair; g[1] gives the number of shuffle groups
            if g is None:
                g = [1, 1]
            self.g = g[1]
            if self.g != 1 and expansion:
                self.g = inp // self.g
            print('group shuffle: {}, divide group: {}'.format(self.g, expansion))
            self.gc = inp // self.g
            # build a static channel-shuffle index: rotate the g groups by one,
            # then rotate the gc channels inside each group by one
            index = torch.Tensor(range(inp)).view(1, inp, 1, 1)   # integer sequence 0..inp-1
            index = index.view(1, self.g, self.gc, 1, 1)
            indexgs = torch.split(index, [1, self.g - 1], dim=1)
            indexgs = torch.cat((indexgs[1], indexgs[0]), dim=1)  # groups rotated
            indexs = torch.split(indexgs, [1, self.gc - 1], dim=2)
            indexs = torch.cat((indexs[1], indexs[0]), dim=2)     # channels rotated
            self.index = indexs.view(inp).type(torch.LongTensor)  # shape (inp,)
            self.expansion = expansion
        def forward(self, x):
            x_in = x
            x_out = x
            b, c, _, _ = x_in.size()
            # predict the coefficients from the globally pooled input
            y = self.avg_pool(x_in).view(b, c)
            y = self.fc(y).view(b, self.oup * self.exp, 1, 1)
            y = (y - 0.5) * self.act_max          # rescale from [0, 1] to [-act_max, act_max]
            x2 = x_out[:, self.index, :, :]       # channel-shuffled copy of the input
            if self.exp == 4:
                a1, b1, a2, b2 = torch.split(y, self.oup, dim=1)  # each (b, oup, 1, 1)
                a1 = a1 + self.init_a[0]
                a2 = a2 + self.init_a[1]
                b1 = b1 + self.init_b[0]
                b2 = b2 + self.init_b[1]
                z1 = x_out * a1 + x2 * b1
                z2 = x_out * a2 + x2 * b2
                out = torch.max(z1, z2)
            elif self.exp == 2:
                a1, b1 = torch.split(y, self.oup, dim=1)
                a1 = a1 + self.init_a[0]
                b1 = b1 + self.init_b[0]
                out = x_out * a1 + x2 * b1
            return out
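    # In short, DYShiftMax computes max_k(a_k * x + b_k * shuffle(x)) for
    # k = 1, 2, with per-channel coefficients predicted from the pooled input;
    # with act_relu=False only the k = 1 branch is used.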
    def get_squeeze_channels(inp, reduction):
        if reduction == 4:
            squeeze = inp // reduction
        else:
            squeeze = _make_divisible(inp // reduction, 4)
        return squeeze
    a = torch.rand(2, 16, 16, 16)
    b = DYShiftMax(16, 16, reduction=4, act_max=1.0, act_relu=True,
                   init_a=[2.0, 2.0], init_b=[2.0, 2.0], relu_before_pool=False,
                   g=[4, 4], expansion=False)
    c = b(a)
    # print(b)
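    # DYShiftMax preserves the channel count here (inp == oup), so the
    # output shape matches the input:
    print(c.shape)   # torch.Size([2, 16, 16, 16])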

• Original post: https://blog.csdn.net/zouxiaolv/article/details/126155376