import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dimension for (N, H, W, C) or (N, C, H, W) inputs."""

    def __init__(self, embed_dim, eps=1e-6, data_format="channels_last") -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.weight = nn.Parameter(torch.ones(embed_dim))
        self.bias = nn.Parameter(torch.zeros(embed_dim))
        self.eps = eps
        self.data_format = data_format
        assert self.data_format in ["channels_last", "channels_first"]
        self.normalized_shape = (embed_dim,)

    def forward(self, x):
        if self.data_format == "channels_last":  # N, H, W, C
            # F.layer_norm takes normalized_shape as a tuple, not a bare int
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":  # N, C, H, W
            u = x.mean(1, keepdim=True)               # per-position mean over C
            s = (x - u).pow(2).mean(1, keepdim=True)  # per-position (biased) variance over C
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
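Both branches compute the same statistics (a biased variance over C with the same eps), so a quick sanity check, sketched below, is to permute a channels_first tensor into channels_last form and compare the two outputs:

# Sanity check (sketch): the two data formats should agree up to float tolerance.
torch.manual_seed(0)
x = torch.randn(2, 8, 4, 4)  # N, C, H, W

ln_first = LayerNorm2d(8, data_format="channels_first")
ln_last = LayerNorm2d(8, data_format="channels_last")

out_first = ln_first(x)
out_last = ln_last(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(out_first, out_last, atol=1e-5))  # True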
3. Usage
if self.use_layer_norm:
    N, C, H, W = x.shape
    x = x.flatten(2).transpose(1, 2)  # N,C,H,W -> N,C,H*W -> N,H*W,C
    hw_shape = (H, W)
    x = self.norm(x)  # norm over the trailing C dimension
    x = x.reshape(-1, *hw_shape, C).permute(0, 3, 1, 2).contiguous()  # N,H*W,C -> N,H,W,C -> N,C,H,W
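To make the snippet concrete, here is a minimal runnable sketch, assuming `self.norm` is the channels_last LayerNorm2d from above; a plain nn.LayerNorm(C) also works here, since the flattened input ends in C:

# Standalone version of the usage pattern above (assumed names, not library API).
norm = LayerNorm2d(8, data_format="channels_last")

x = torch.randn(2, 8, 4, 4)           # N, C, H, W
N, C, H, W = x.shape
x = x.flatten(2).transpose(1, 2)      # N,C,H,W -> N,H*W,C
hw_shape = (H, W)
x = norm(x)                           # layer_norm over the last dim (C)
x = x.reshape(-1, *hw_shape, C).permute(0, 3, 1, 2).contiguous()
print(x.shape)  # torch.Size([2, 8, 4, 4])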