模型权重初始化的代码是通用的，在此记录一下，免得以后每次都要自己重新敲。
def _init_weights(self):
    """Initialise weights and biases of conv / transposed-conv / linear layers.

    Walks every submodule of ``self`` (via ``self.modules()``). For each module
    whose *exact* type is ``nn.Conv2d``, ``nn.Conv3d``, ``nn.ConvTranspose2d``,
    ``nn.ConvTranspose3d`` or ``nn.Linear``:

    * weight: Kaiming (He) normal init, ``mode='fan_out'``, ReLU gain;
    * bias (if present): uniform in ``[-1/sqrt(fan_out), 1/sqrt(fan_out)]``.

    All other modules (e.g. normalisation layers) are left untouched. Because
    matching is by exact type, subclasses of the listed layers are skipped too.

    Args:
        self: an ``nn.Module``, typically the model this is attached to as a
            method.
    """
    init_set = {
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
        nn.Linear,
    }
    for m in self.modules():
        # Exact-type match (not isinstance), so custom subclasses are skipped.
        if type(m) not in init_set:
            continue
        # nn.init functions already run under torch.no_grad(), so the
        # parameter is passed directly rather than through ``.data``.
        nn.init.kaiming_normal_(
            m.weight, a=0, mode='fan_out', nonlinearity='relu'
        )
        if m.bias is not None:
            # Fans are derived from the weight tensor layout; for
            # ConvTranspose layers PyTorch stores weights as
            # (in_channels, out_channels // groups, *kernel), so "fan_out"
            # here is in_channels * kernel volume for those layers.
            _, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight)
            bound = 1 / math.sqrt(fan_out)
            # uniform_, not normal_: normal_'s positional args are
            # (mean, std), which would give N(-bound, bound**2) biases.
            nn.init.uniform_(m.bias, -bound, bound)

    # Optional, model-specific: constant bias on the final layer. The intent
    # of -4 / +4 is presumably to bias initial output probabilities — confirm
    # for your model before enabling.
    # nn.init.constant_(self.unet.last.bias, -4)
    # nn.init.constant_(self.unet.last.bias, 4)