import torch
import torch.nn as nn
import torch.nn.functional as F

def _make_divisible(v, divisor, min_value=None):
    """
    Adapted from the original TensorFlow repo. It ensures that all layers
    have a channel number that is divisible by `divisor`:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v: the raw channel count
    :param divisor: the divisor the result must be a multiple of
    :param min_value: lower bound for the result (defaults to `divisor`)
    :return: the adjusted channel count
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


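# A quick sanity check of _make_divisible (illustrative values, not from the
# original repo): it rounds to the nearest multiple of `divisor`, floors the
# result at `min_value`, and bumps up whenever rounding would drop more than
# 10% below v.
# >>> _make_divisible(30, 8)   # nearest multiple of 8
# 32
# >>> _make_divisible(10, 8)   # 8 would be a >10% drop, so bump up to 16
# 16
# >>> _make_divisible(3, 8)    # floored at min_value (= divisor here)
# 8
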
########################################################################
# hard sigmoid and hard tanh
########################################################################

# h_sigmoid: maps x in [-3, 3] linearly onto y in [0, h_max]
class h_sigmoid(nn.Module):
    def __init__(self, inplace=True, h_max=1):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)
        self.h_max = h_max / 6

    def forward(self, x):
        return self.relu(x + 3) * self.h_max


# h_tanh: maps x in [-3, 3] linearly onto y in [-h_max, h_max]
class h_tanh(nn.Module):
    def __init__(self, inplace=True, h_max=1):
        super(h_tanh, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)
        self.h_max = h_max

    def forward(self, x):
        return self.relu(x + 3) * self.h_max / 3 - self.h_max


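# Quick numeric check of the hard activations (illustrative, not part of the
# original file): both are piecewise-linear approximations of their smooth
# counterparts, saturating outside [-3, 3].
# >>> h_sigmoid()(torch.tensor([-3.0, 0.0, 3.0]))
# tensor([0.0000, 0.5000, 1.0000])
# >>> h_tanh()(torch.tensor([-3.0, 0.0, 3.0]))
# tensor([-1., 0., 1.])
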
########################################################################
# wrap functions
########################################################################

def get_act_layer(inp, oup, mode='SE1', act_relu=True, act_max=2, act_bias=True,
                  init_a=[1.0, 0.0], reduction=4, init_b=[0.0, 0.0], g=None,
                  act='relu', expansion=True):
    """Build an activation block by name: 'SE1' (SE + ReLU6), 'SE0' (SE only),
    'NA' (plain ReLU6), 'LeakyReLU', 'RReLU', 'PReLU', or 'DYShiftMax'."""
    layer = None
    if mode == 'SE1':
        layer = nn.Sequential(
            SELayer(inp, oup, reduction=reduction),
            nn.ReLU6(inplace=True) if act_relu else nn.Sequential()
        )
    elif mode == 'SE0':
        layer = nn.Sequential(
            SELayer(inp, oup, reduction=reduction),
        )
    elif mode == 'NA':
        layer = nn.ReLU6(inplace=True) if act_relu else nn.Sequential()
    elif mode == 'LeakyReLU':
        layer = nn.LeakyReLU(inplace=True) if act_relu else nn.Sequential()
    elif mode == 'RReLU':
        layer = nn.RReLU(inplace=True) if act_relu else nn.Sequential()
    elif mode == 'PReLU':
        layer = nn.PReLU() if act_relu else nn.Sequential()
    elif mode == 'DYShiftMax':
        layer = DYShiftMax(inp, oup, act_max=act_max, act_relu=act_relu,
                           init_a=init_a, reduction=reduction, init_b=init_b,
                           g=g, expansion=expansion)
    return layer

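# Example usage (illustrative): an SE block followed by ReLU6 for a
# 64-channel feature map.
# >>> act = get_act_layer(64, 64, mode='SE1', reduction=4)
# >>> act(torch.rand(2, 64, 8, 8)).shape
# torch.Size([2, 64, 8, 8])
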
########################################################################
# dynamic activation layers (SE, DYShiftMax, etc.)
########################################################################

class SELayer(nn.Module):
    def __init__(self, inp, oup, reduction=4):
        super(SELayer, self).__init__()
        self.oup = oup
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # determine squeeze
        squeeze = get_squeeze_channels(inp, reduction)
        print('reduction: {}, squeeze: {}/{}'.format(reduction, inp, squeeze))

        self.fc = nn.Sequential(
            nn.Linear(inp, squeeze),
            nn.ReLU(inplace=True),
            nn.Linear(squeeze, oup),
            h_sigmoid()
        )

    def forward(self, x):
        # x may be a [x_in, x_out] pair: attention is computed from x_in
        # and applied to x_out; with a single tensor they are the same.
        if isinstance(x, list):
            x_in, x_out = x
        else:
            x_in, x_out = x, x
        b, c, _, _ = x_in.size()
        y = self.avg_pool(x_in).view(b, c)
        y = self.fc(y).view(b, self.oup, 1, 1)
        return x_out * y

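# Example (illustrative): channel-wise reweighting of a feature map.
# >>> se = SELayer(32, 32, reduction=4)
# >>> se(torch.rand(2, 32, 8, 8)).shape
# torch.Size([2, 32, 8, 8])
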
class DYShiftMax(nn.Module):
    def __init__(self, inp, oup, reduction=4, act_max=1.0, act_relu=True,
                 init_a=[0.0, 0.0], init_b=[0.0, 0.0], relu_before_pool=False,
                 g=None, expansion=False):
        super(DYShiftMax, self).__init__()
        self.oup = oup
        self.act_max = act_max * 2
        self.act_relu = act_relu
        self.avg_pool = nn.Sequential(
            nn.ReLU(inplace=True) if relu_before_pool else nn.Sequential(),
            nn.AdaptiveAvgPool2d(1)
        )

        # 4 coefficients (two affine branches combined by a max) with
        # act_relu, 2 (a single branch) without
        self.exp = 4 if act_relu else 2
        self.init_a = init_a
        self.init_b = init_b

        # determine squeeze
        squeeze = _make_divisible(inp // reduction, 4)
        if squeeze < 4:
            squeeze = 4
        print('reduction: {}, squeeze: {}/{}'.format(reduction, inp, squeeze))
        print('init-a: {}, init-b: {}'.format(init_a, init_b))

        self.fc = nn.Sequential(
            nn.Linear(inp, squeeze),
            nn.ReLU(inplace=True),
            nn.Linear(squeeze, oup * self.exp),
            h_sigmoid()
        )

        # group size for the channel shuffle; g is expected to be a pair,
        # of which only g[1] is used here
        if g is None:
            g = [1, 1]
        self.g = g[1]
        if self.g != 1 and expansion:
            self.g = inp // self.g
        print('group shuffle: {}, divide group: {}'.format(self.g, expansion))
        self.gc = inp // self.g

        # Build a fixed channel-shuffle permutation: view the channel indices
        # as (g, gc) groups, rotate the groups by one position, then rotate
        # the channels within each group by one, and flatten back.
        index = torch.arange(inp).view(1, inp, 1, 1)
        index = index.view(1, self.g, self.gc, 1, 1)
        indexgs = torch.split(index, [1, self.g - 1], dim=1)
        indexgs = torch.cat((indexgs[1], indexgs[0]), dim=1)
        indexs = torch.split(indexgs, [1, self.gc - 1], dim=2)
        indexs = torch.cat((indexs[1], indexs[0]), dim=2)
        # registered as a buffer so the permutation follows the module across
        # devices (the original kept it as a plain CPU attribute)
        self.register_buffer('index', indexs.view(inp).long())
        self.expansion = expansion

    def forward(self, x):
        x_in = x
        x_out = x

        # predict the shift-max coefficients from globally pooled features
        b, c, _, _ = x_in.size()
        y = self.avg_pool(x_in).view(b, c)
        y = self.fc(y).view(b, self.oup * self.exp, 1, 1)
        y = (y - 0.5) * self.act_max

        # x2 is x with channels permuted by the fixed shuffle index
        x2 = x_out[:, self.index, :, :]

        if self.exp == 4:
            # two affine branches, combined with an element-wise max
            a1, b1, a2, b2 = torch.split(y, self.oup, dim=1)

            a1 = a1 + self.init_a[0]
            a2 = a2 + self.init_a[1]

            b1 = b1 + self.init_b[0]
            b2 = b2 + self.init_b[1]

            z1 = x_out * a1 + x2 * b1
            z2 = x_out * a2 + x2 * b2

            out = torch.max(z1, z2)

        elif self.exp == 2:
            # single affine branch, no max
            a1, b1 = torch.split(y, self.oup, dim=1)
            a1 = a1 + self.init_a[0]
            b1 = b1 + self.init_b[0]
            out = x_out * a1 + x2 * b1

        return out

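# DYShiftMax summary (illustrative): for each channel i, with j = index[i]
# its shuffled partner channel, the layer computes
#     out_i = max_k( a_k * x_i + b_k * x_j )
# over two affine branches when act_relu=True (a single branch otherwise),
# where a_k and b_k are predicted per sample from globally pooled features.
# A minimal smoke test with assumed shapes:
# >>> m = DYShiftMax(16, 16, g=[4, 4], expansion=False)
# >>> m(torch.rand(2, 16, 8, 8)).shape
# torch.Size([2, 16, 8, 8])
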
def get_squeeze_channels(inp, reduction):
    # plain division for the default reduction; otherwise round the squeeze
    # width to a multiple of 4
    if reduction == 4:
        squeeze = inp // reduction
    else:
        squeeze = _make_divisible(inp // reduction, 4)
    return squeeze


if __name__ == '__main__':
    # quick smoke test
    a = torch.rand(2, 16, 16, 16)
    b = DYShiftMax(16, 16, reduction=4, act_max=1.0, act_relu=True,
                   init_a=[2.0, 2.0], init_b=[2.0, 2.0],
                   relu_before_pool=False, g=[4, 4], expansion=False)
    c = b(a)
    print(c.shape)