import torch.nn.functional as F
def _make_divisible(v, divisor, min_value=None):
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Activation module built on ReLU6 (the name suggests a "hard sigmoid" — confirm).
# NOTE(review): lines appear to be missing from this excerpt: no `forward` header is
# visible, and `self.h_max` is read on the return line but never assigned in the
# visible code. Verify how `h_max` is stored (and whether it is pre-scaled) in the
# full file before relying on the output range.
class h_sigmoid(nn.Module):
# inplace: forwarded unchanged to nn.ReLU6.
# h_max: scale applied to the clamped value (see the return line).
def __init__(self, inplace=True, h_max=1):
super(h_sigmoid, self).__init__()
# ReLU6 clamps its input to the range [0, 6].
self.relu = nn.ReLU6(inplace=inplace)
# Shift the input by 3, clamp with ReLU6, then multiply by self.h_max.
return self.relu(x + 3) * self.h_max
# Constructor and return line of `h_tanh` (the class name is taken from the super() call).
# NOTE(review): the `class h_tanh` header, the assignment of `self.h_max`, and the
# `forward` header are not visible in this excerpt — confirm against the full file.
# inplace: forwarded unchanged to nn.ReLU6.
# h_max: scale/offset used on the return line.
def __init__(self, inplace=True, h_max=1):
super(h_tanh, self).__init__()
# ReLU6 clamps its input to the range [0, 6].
self.relu = nn.ReLU6(inplace=inplace)
# relu6(x + 3) lies in [0, 6]; multiplying by h_max / 3 and subtracting h_max
# gives a value in [-h_max, h_max], assuming self.h_max holds a non-negative h_max unchanged.
return self.relu(x + 3)*self.h_max / 3 - self.h_max
# Factory that builds an activation layer chosen by `mode`.
# NOTE(review): most of the if/elif chain, the enclosing nn.Sequential(...) calls and
# the return statement are not visible in this excerpt; only the 'LeakyReLU' and
# 'DYShiftMax' branch headers can be seen. `act_bias` and `act` are not referenced
# in the visible lines.
# NOTE(review): `init_a` and `init_b` are mutable list defaults shared across calls;
# the visible code only passes them on, but confirm nothing downstream mutates them.
def get_act_layer(inp, oup, mode='SE1', act_relu=True, act_max=2, act_bias=True, init_a=[1.0, 0.0], reduction=4, init_b=[0.0, 0.0], g=None, act='relu', expansion=True):
# Squeeze-excite layer followed by ReLU6, or by an empty Sequential (identity) when
# act_relu is false — presumably elements of an nn.Sequential for an SE mode.
SELayer(inp, oup, reduction=reduction),
nn.ReLU6(inplace=True) if act_relu else nn.Sequential()
# Second SELayer fragment — presumably belongs to a different SE mode; confirm.
SELayer(inp, oup, reduction=reduction),
# Plain ReLU6 (identity when act_relu is false); the mode this belongs to is not visible.
layer = nn.ReLU6(inplace=True) if act_relu else nn.Sequential()
elif mode == 'LeakyReLU':
layer = nn.LeakyReLU(inplace=True) if act_relu else nn.Sequential()
# RReLU and PReLU variants; their `elif` headers are not visible here.
layer = nn.RReLU(inplace=True) if act_relu else nn.Sequential()
layer = nn.PReLU() if act_relu else nn.Sequential()
elif mode == 'DYShiftMax':
# All dynamic-activation settings are forwarded to DYShiftMax unchanged.
layer = DYShiftMax(inp, oup, act_max=act_max, act_relu=act_relu, init_a=init_a, reduction=reduction, init_b=init_b, g=g, expansion=expansion)
# Layer that globally average-pools its input and passes the pooled vector through `self.fc`.
# NOTE(review): the assignments of `self.oup` and `self.fc`, the `forward` header, and the
# definitions of `x_in`, `b` and `c` are not visible in this excerpt.
class SELayer(nn.Module):
# inp / oup: input and output channel counts; reduction: passed to get_squeeze_channels.
def __init__(self, inp, oup, reduction=4):
super(SELayer, self).__init__()
# Pools each channel's spatial map down to a single value (output size 1).
self.avg_pool = nn.AdaptiveAvgPool2d(1)
# Width of the bottleneck — presumably consumed by self.fc; confirm.
squeeze = get_squeeze_channels(inp, reduction)
# Debug output emitted every time the layer is constructed.
print('reduction: {}, squeeze: {}/{}'.format(reduction, inp, squeeze))
# Flatten the pooled tensor to (b, c).
y = self.avg_pool(x_in).view(b, c)
# Reshape the fc output to (b, oup, 1, 1) so it can broadcast over spatial dimensions.
y = self.fc(y).view(b, self.oup, 1, 1)
# Dynamic activation: input-dependent coefficients (from pooled features through `self.fc`)
# weight the feature map and a channel-permuted copy of it.
# NOTE(review): many lines are missing from this excerpt — the assignments of `self.oup`,
# `self.g`, `self.gc` and `self.fc`, the `forward` header, the definitions of `x_in`,
# `x_out`, `b`, `c`, the branch between the 4-way and 2-way splits, and the final
# return/max are not visible. Confirm against the full file.
# NOTE(review): `init_a` and `init_b` are mutable list defaults; the visible code only
# prints them.
class DYShiftMax(nn.Module):
def __init__(self, inp, oup, reduction=4, act_max=1.0, act_relu=True, init_a=[0.0, 0.0], init_b=[0.0, 0.0], relu_before_pool=False, g=None, expansion=False):
super(DYShiftMax, self).__init__()
# The stored scale is twice the act_max argument.
self.act_max = act_max * 2
self.avg_pool = nn.Sequential(
# Optional ReLU ahead of pooling; an empty Sequential acts as identity.
nn.ReLU(inplace=True) if relu_before_pool == True else nn.Sequential(),
# Coefficient groups per output channel: 4 (a1, b1, a2, b2) with act_relu, else 2 (a1, b1).
self.exp = 4 if act_relu else 2
# Bottleneck width, rounded to a multiple of 4 by _make_divisible.
squeeze = _make_divisible(inp // reduction, 4)
# Debug output emitted at construction time.
print('reduction: {}, squeeze: {}/{}'.format(reduction, inp, squeeze))
print('init-a: {}, init-b: {}'.format(init_a, init_b))
# Final projection emits oup * exp coefficients — presumably part of self.fc.
nn.Linear(squeeze, oup*self.exp),
if self.g !=1 and expansion:
print('group shuffle: {}, divide group: {}'.format(self.g, expansion))
# Channel ids 0..inp-1 laid out as (1, inp, 1, 1).
index=torch.Tensor(range(inp)).view(1,inp,1,1)
print('index=',index.shape)
print('self.g: {}, self.gc: {}'.format(self.g, self.gc))
# View the ids as g groups of gc channels (requires g * gc == inp).
index=index.view(1,self.g,self.gc,1,1)
# Rotate along the group axis: the first group moves to the end.
indexgs = torch.split(index, [1, self.g-1], dim=1)
print('indexgs[0]=',indexgs[0].shape)
print('indexgs[1]=', indexgs[1].shape)
indexgs = torch.cat((indexgs[1], indexgs[0]), dim=1)
# Rotate within each group: the first channel of every group moves to the end.
indexs = torch.split(indexgs, [1, self.gc-1], dim=2)
print('indexs[0]=',indexs[0].shape)
print('indexs[1]=', indexs[1].shape)
indexs = torch.cat((indexs[1], indexs[0]), dim=2)
# Flatten to a length-inp integer index, used later to gather channels of x_out.
self.index = indexs.view(inp).type(torch.LongTensor)
print('self.index=',self.index.shape)
self.expansion = expansion
# Pool the input and flatten it to (b, c).
y = self.avg_pool(x_in).view(b, c)
# Reshape the fc output to (b, oup * exp, 1, 1) so coefficients broadcast spatially.
y = self.fc(y).view(b, self.oup*self.exp, 1, 1)
# Re-centre and scale the coefficients; presumably y lies in [0, 1] beforehand — confirm.
y = (y-0.5) * self.act_max
print('y_max = ', y.shape)
# x_out is unpacked as a 4-D tensor; the unpacked names are not used in the visible lines.
n2, c2, h2, w2 = x_out.size()
# Channel-permuted copy of x_out using the shuffled index built in __init__.
x2 = x_out[:,self.index,:,:]
# 4-coefficient case: split y into chunks of oup channels each.
a1, b1, a2, b2 = torch.split(y, self.oup, dim=1)
# NOTE(review): debug prints that repeat the same torch.split four times.
print('torch.split(y, self.oup, dim=1)',torch.split(y, self.oup, dim=1)[0].shape)
print('torch.split(y, self.oup, dim=1)', torch.split(y, self.oup, dim=1)[1].shape)
print('torch.split(y, self.oup, dim=1)', torch.split(y, self.oup, dim=1)[2].shape)
print('torch.split(y, self.oup, dim=1)', torch.split(y, self.oup, dim=1)[3].shape)
print('a1_max=',a1.shape)
# Two weighted combinations of x_out and its permuted copy; presumably combined with
# an element-wise max on a line not visible here — confirm.
z1 = x_out * a1 + x2 * b1
z2 = x_out * a2 + x2 * b2
# 2-coefficient case: a single weighted combination.
a1, b1 = torch.split(y, self.oup, dim=1)
out = x_out * a1 + x2 * b1
# Compute the bottleneck ("squeeze") channel count for a given input width and reduction.
# NOTE(review): as shown, the second assignment always overwrites the first and nothing is
# returned; presumably an if/else selecting between the two and a `return squeeze` are
# missing from this excerpt — confirm against the full file.
def get_squeeze_channels(inp, reduction):
# Plain integer division of the input width.
squeeze = inp // reduction
# Same quotient rounded to a multiple of 4 via _make_divisible.
squeeze = _make_divisible(inp // reduction, 4)
# Module-level smoke test: a random (2, 16, 16, 16) tensor and a DYShiftMax
# layer with 16 input and 16 output channels.
a = torch.rand(2, 16, 16, 16)
b = DYShiftMax(
    16,
    16,
    reduction=4,
    act_max=1.0,
    act_relu=True,
    init_a=[2.0, 2.0],
    init_b=[2.0, 2.0],
    relu_before_pool=False,
    g=[4, 4],
    expansion=False,
)
