• An ms model converted from pytorch matches the pytorch model's output in training mode, but the output after the batchnorm2d operator differs in evaluation mode


    The PyTorch source takes 3-D input and uses the batchnorm1d operator. MindSpore's batchnorm1d only accepts 2-D input, so I used MindSpore's batchnorm2d operator (4-D input) instead: the input is first expanded to 4-D, passed through batchnorm2d, and the result is squeezed back to 3-D, as sketched below. The ms model converted from pytorch produces the same output in training mode, but in evaluation mode the output after the batchnorm2d operator (which corresponds to PyTorch's batchnorm1d and loads the bn1d weights) differs.
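    This lifting trick can be exercised on its own. A minimal sketch (shapes are arbitrary, assuming the same mindspore nn/ops APIs used in the full code below):

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn
    import mindspore.ops as ops

    bn2d = nn.BatchNorm2d(4, eps=1e-3, momentum=0.99)      # C = 4 channels
    x = ms.Tensor(np.random.randn(2, 10, 4), ms.float32)   # (N, L, C): the 3-D input

    y = ops.ExpandDims()(x, 0)                     # (1, N, L, C): add a leading axis
    y = y.transpose((0, 3, 1, 2))                  # (1, C, N, L): BatchNorm2d normalizes axis 1
    y = bn2d(y)
    y = y.transpose((0, 2, 3, 1)).squeeze(axis=0)  # back to (N, L, C)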

    # Ms
    import mindspore.nn as nn
    import mindspore.ops as ops

    class PFNLayer(nn.Cell):
        def __init__(self, in_channels, out_channels, norm_cfg=None, last_layer=False):
            super(PFNLayer, self).__init__()
            self.last_vfe = last_layer
            if not self.last_vfe:
                out_channels = out_channels // 2
            self.units = out_channels
            self.linear = nn.Dense(in_channels, self.units, has_bias=False)
            if norm_cfg is None:
                # NOTE: use_batch_statistics=True pins BN to the current batch's
                # statistics even after set_train(False)
                self.norm = nn.BatchNorm2d(self.units, eps=1e-3, momentum=0.99, use_batch_statistics=True)
            self.transpose = ops.Transpose()
            self.tile = ops.Tile()
            self.concat = ops.Concat(axis=2)
            self.expand_dims = ops.ExpandDims()
            self.argmax_w_value = ops.ArgMaxWithValue(axis=1, keep_dims=True)
            self.relu = ops.ReLU()

        def construct(self, inputs):
            """forward graph"""
            x = self.linear(inputs)
            x = self.expand_dims(x, 0)
            # lift to 4-D around MindSpore's BatchNorm2d, then squeeze back to 3-D
            x = self.norm(x.transpose((0, 3, 1, 2))).transpose((0, 2, 3, 1)).squeeze(axis=0)
            x = self.relu(x)
            x_max = self.argmax_w_value(x)[1]  # [0] is the index, [1] the max value
            if self.last_vfe:
                return x_max
            x_repeat = self.tile(x_max, (1, inputs.shape[1], 1))
            x_concatenated = self.concat([x, x_repeat])
            return x_concatenated
    # pytorch
    import torch
    import torch.nn.functional as F
    from torch import nn

    norm_cfg = {
        # format: layer_type: (abbreviation, module)
        "BN": ("bn", nn.BatchNorm2d),
        "BN1d": ("bn1d", nn.BatchNorm1d),
        "GN": ("gn", nn.GroupNorm),
    }

    def build_norm_layer(cfg, num_features, postfix=""):
        """Build normalization layer."""
        assert isinstance(cfg, dict) and "type" in cfg
        cfg_ = cfg.copy()
        layer_type = cfg_.pop("type")
        if layer_type not in norm_cfg:
            raise KeyError("Unrecognized norm type {}".format(layer_type))
        abbr, norm_layer = norm_cfg[layer_type]
        if norm_layer is None:
            raise NotImplementedError
        assert isinstance(postfix, (int, str))
        name = abbr + str(postfix)
        requires_grad = cfg_.pop("requires_grad", True)
        cfg_.setdefault("eps", 1e-5)
        if layer_type != "GN":
            layer = norm_layer(num_features, **cfg_)
            # if layer_type == 'SyncBN':
            #     layer._specify_ddp_gpu_num(1)
        else:
            assert "num_groups" in cfg_
            layer = norm_layer(num_channels=num_features, **cfg_)
        for param in layer.parameters():
            param.requires_grad = requires_grad
        return name, layer
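    A quick usage note for the helper: it returns a (name, layer) pair, and the PFNLayer below keeps only the layer. A hypothetical call (the feature count 64 is made up; the cfg mirrors the default norm_cfg below):

    name, norm = build_norm_layer(dict(type="BN1d", eps=1e-3, momentum=0.01), 64)
    # name == "bn1d"; norm is nn.BatchNorm1d(64, eps=0.001, momentum=0.01)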
    class PFNLayer(nn.Module):
        def __init__(self, in_channels, out_channels, norm_cfg=None, last_layer=False):
            super().__init__()
            self.name = "PFNLayer"
            self.last_vfe = last_layer
            if not self.last_vfe:
                out_channels = out_channels // 2
            self.units = out_channels
            if norm_cfg is None:
                norm_cfg = dict(type="BN1d", eps=1e-3, momentum=0.01)
            self.norm_cfg = norm_cfg
            self.linear = nn.Linear(in_channels, self.units, bias=False)
            self.norm = build_norm_layer(self.norm_cfg, self.units)[1]

        def forward(self, inputs):
            x = self.linear(inputs)
            torch.backends.cudnn.enabled = False
            # PyTorch bn1d over the channel axis: (N, L, C) -> (N, C, L) -> (N, L, C)
            x = self.norm(x.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
            torch.backends.cudnn.enabled = True
            x = F.relu(x)
            x_max = torch.max(x, dim=1, keepdim=True)[0]
            if self.last_vfe:
                return x_max
            else:
                x_repeat = x_max.repeat(1, inputs.shape[1], 1)
                x_concatenated = torch.cat([x, x_repeat], dim=2)
                return x_concatenated

    **************************************************** Answer *****************************************************

    Hello. The batchnorm dimensionality (1d vs. 2d) should match on both sides for the comparison to be meaningful. Also, the example effectively runs batchnorm with batch size = 1, which defeats its purpose.
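    For reference, one flag in the MindSpore snippet above that can be checked in isolation is use_batch_statistics=True: per the documented semantics of mindspore.nn.BatchNorm2d, True pins normalization to the current batch's statistics even after set_train(False), while the default None switches to the tracked moving statistics in evaluation mode, which is what torch.nn.BatchNorm1d does under eval(). A minimal sketch (assuming those documented semantics; not a verified repro of the original model):

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn

    bn_pinned = nn.BatchNorm2d(4, eps=1e-3, momentum=0.99, use_batch_statistics=True)
    bn_default = nn.BatchNorm2d(4, eps=1e-3, momentum=0.99)  # use_batch_statistics=None

    bn_pinned.set_train(False)    # evaluation mode
    bn_default.set_train(False)

    x = ms.Tensor(np.random.randn(1, 4, 8, 8), ms.float32)
    # bn_pinned still normalizes with this batch's mean/variance;
    # bn_default uses moving_mean/moving_variance (untrained here: 0 and 1).
    print(bn_pinned(x).asnumpy()[0, 0, 0, :3])
    print(bn_default(x).asnumpy()[0, 0, 0, :3])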

  • Original post: https://blog.csdn.net/weixin_45666880/article/details/127731493