• U-Net: Convolutional Networks for Biomedical Image Segmentation


    paper:  U-Net: Convolutional Networks for Biomedical Image Segmentation

    Key contributions

    1. Proposes a U-shaped encoder-decoder architecture whose skip connections fuse the localization information of shallow layers with the semantic information of deep layers. Like FCN, U-Net is fully convolutional; one important change relative to FCN is that the upsampling path also has a large number of feature channels, which lets the network propagate context information to the higher-resolution layers.
    2. Medical image segmentation tasks have very little training data, so the authors generate extensive data augmentation by applying elastic deformations.
    3. Proposes a weighted loss, with a per-pixel weight map (in the paper, emphasizing the separation borders between touching cells).
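
    Elastic-deformation augmentation can be sketched as warping the image with a smoothed random displacement field. Below is a minimal PyTorch version; the `alpha`/`sigma` parameters and the box-blur smoothing are illustrative assumptions, not the paper's exact recipe:

    ```python
    import torch
    import torch.nn.functional as F

    def elastic_deform(img, alpha=10.0, sigma=4, seed=0):
        """Warp img (N, C, H, W) with a random smooth displacement field.

        alpha scales the displacement magnitude in pixels; sigma controls
        smoothness. Hypothetical parameter names, for illustration only.
        """
        g = torch.Generator().manual_seed(seed)
        n, c, h, w = img.shape
        # random displacement field: one (dy, dx) pair per pixel in [-1, 1]
        disp = torch.rand(n, 2, h, w, generator=g) * 2 - 1
        # smooth it; a few box blurs approximate a Gaussian filter
        k = 2 * sigma + 1
        for _ in range(3):
            disp = F.avg_pool2d(disp, k, stride=1, padding=sigma,
                                count_include_pad=False)
        disp = disp * alpha
        # identity sampling grid in normalized [-1, 1] coordinates
        gy, gx = torch.meshgrid(torch.linspace(-1, 1, h),
                                torch.linspace(-1, 1, w), indexing="ij")
        grid = torch.stack([gx, gy], dim=-1).unsqueeze(0).expand(n, -1, -1, -1)
        # convert pixel displacements to normalized offsets and resample
        offs = torch.stack([disp[:, 1] / (w / 2), disp[:, 0] / (h / 2)], dim=-1)
        return F.grid_sample(img, grid + offs, align_corners=True)

    img = torch.arange(64, dtype=torch.float32).reshape(1, 1, 8, 8)
    out = elastic_deform(img)
    print(out.shape)  # torch.Size([1, 1, 8, 8])
    ```

    The same displacement field must be applied to the ground-truth mask (with nearest-neighbor sampling) so that image and labels stay aligned.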


    Implementation details worth noting

    1. The original paper's implementation uses no padding, so the resolution of the feature maps gradually shrinks. The MMSegmentation implementation described below uses padding, so with stride=1 the output feature map keeps the input resolution.
    2. FCN fuses shallow and deep information across its skip connections by element-wise addition, whereas U-Net fuses them by channel concatenation.
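    The two fusion styles differ only in how the channels combine, which a few lines of PyTorch make concrete (small stand-in shapes):

    ```python
    import torch

    # equal-shape shallow (encoder) and deep (upsampled decoder) feature maps
    shallow = torch.randn(1, 64, 32, 32)
    deep = torch.randn(1, 64, 32, 32)

    fused_add = shallow + deep                     # FCN-style: channels stay 64
    fused_cat = torch.cat([shallow, deep], dim=1)  # U-Net-style: channels become 128

    print(fused_add.shape)  # torch.Size([1, 64, 32, 32])
    print(fused_cat.shape)  # torch.Size([1, 128, 32, 32])
    ```

    Concatenation preserves both feature sets and lets the following convolutions learn the mixing, at the cost of doubling the channel count that the next layer must process.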

    Implementation walkthrough

    Taking the UNet implementation in MMSegmentation as an example, assume batch_size=4 and an input of shape (4, 3, 480, 480).

    Backbone

    • The encoder has 5 stages. Each stage contains one ConvBlock, and a ConvBlock consists of two Conv-BN-ReLU layers. Except for the first stage, each of the last four stages places a 2x2 stride-2 max-pool before its ConvBlock. The first conv of each stage doubles the number of output channels. The per-stage encoder output shapes are therefore (4, 64, 480, 480), (4, 128, 240, 240), (4, 256, 120, 120), (4, 512, 60, 60), and (4, 1024, 30, 30).
    • The decoder has 4 stages, one for each of the encoder's 4 downsampling stages. Each stage performs three steps: upsample, concatenate, conv. The upsample step is bilinear interpolation with scale_factor=2 followed by one Conv-BN-ReLU whose conv is a 1x1 stride-1 convolution that halves the channel count. The concatenate step joins the upsampled output with the encoder output of matching resolution along the channel dimension. The third step is a ConvBlock of two Conv-BN-ReLU layers, just as in the encoder; since concatenation restores the channel count that upsampling halved, the ConvBlock's first conv halves it again. The backbone thus yields feature maps of shapes (4, 1024, 30, 30), (4, 512, 60, 60), (4, 256, 120, 120), (4, 128, 240, 240), and (4, 64, 480, 480). Note that the decoder has only 4 stages, so the actual decoder outputs are the last four; the first feature map is simply the encoder's final-stage output.
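    Assuming PyTorch, the upsample → concatenate → conv steps of one decoder stage can be sketched as follows. Class and argument names are illustrative simplifications, not MMSegmentation's actual API:

    ```python
    import torch
    import torch.nn as nn

    class ConvBlock(nn.Module):
        """Two Conv-BN-ReLU layers, mirroring the described ConvBlock."""
        def __init__(self, in_ch, out_ch):
            super().__init__()
            self.convs = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            )
        def forward(self, x):
            return self.convs(x)

    class UpConvBlock(nn.Module):
        """One decoder stage: bilinear upsample + 1x1 conv (halve channels),
        concatenate with the matching encoder feature, then a ConvBlock."""
        def __init__(self, in_ch):
            super().__init__()
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
                nn.Conv2d(in_ch, in_ch // 2, 1, bias=False),
                nn.BatchNorm2d(in_ch // 2), nn.ReLU(inplace=True),
            )
            # concatenation restores in_ch; the block halves it again
            self.conv_block = ConvBlock(in_ch, in_ch // 2)
        def forward(self, x, skip):
            x = self.up(x)
            return self.conv_block(torch.cat([skip, x], dim=1))

    deep = torch.randn(1, 1024, 30, 30)   # encoder stage-5 output
    skip = torch.randn(1, 512, 60, 60)    # encoder stage-4 output
    out = UpConvBlock(1024)(deep, skip)
    print(out.shape)  # torch.Size([1, 512, 60, 60])
    ```

    Stacking four such stages with in_ch = 1024, 512, 256, 128 reproduces the decoder shape sequence listed above.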

    FCN Head

    • The output of the backbone decoder's last stage, (4, 64, 480, 480), is the head's input. It first passes through a 3x3 stride-1 conv-bn-relu that leaves the channel count unchanged, then through dropout with ratio 0.1, and finally through a 1x1 conv that produces the model's final output, whose channel count equals the number of classes (background included).
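    Under the shapes above, the head can be sketched as a plain nn.Sequential; this is an illustrative equivalent, not the FCNHead class itself:

    ```python
    import torch
    import torch.nn as nn

    num_classes = 2  # foreground + background, matching the dump below

    head = nn.Sequential(
        nn.Conv2d(64, 64, 3, padding=1, bias=False),  # 3x3-s1, channels unchanged
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
        nn.Dropout2d(p=0.1),
        nn.Conv2d(64, num_classes, 1),                # 1x1 projection to class logits
    )

    feat = torch.randn(2, 64, 48, 48)  # stand-in for the decoder's final feature map
    print(head(feat).shape)  # torch.Size([2, 2, 48, 48])
    ```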

    Loss

    • The decode head is trained with a cross-entropy loss.
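    The weighted loss mentioned among the contributions assigns each pixel its own weight (the paper's weight map emphasizes the borders between touching cells). A minimal way to apply such a map on top of cross-entropy, with a stand-in weight map that simply doubles the weight of foreground pixels:

    ```python
    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 2, 8, 8)          # (N, num_classes, H, W)
    gt = torch.randint(0, 2, (2, 8, 8))       # (N, H, W) class indices

    # plain cross-entropy, averaged over all pixels
    plain = F.cross_entropy(logits, gt)

    # stand-in per-pixel weight map; the paper derives its map from the
    # distances to the two nearest cells, which is omitted here
    w = torch.ones_like(gt, dtype=torch.float32)
    w[gt == 1] = 2.0
    per_px = F.cross_entropy(logits, gt, reduction="none")  # (N, H, W)
    weighted = (per_px * w).mean()

    print(plain.shape, weighted.shape)  # torch.Size([]) torch.Size([])
    ```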

    Auxiliary Head

    • The output of the backbone decoder's second-to-last stage, (4, 128, 240, 240), is the auxiliary head's input. It passes through a 3x3 stride-1 conv-bn-relu that halves the channels to 64, then dropout with ratio 0.1, and finally a 1x1 conv that produces the auxiliary output, whose channel count equals the number of classes (background included).
    • The auxiliary branch's loss is also cross-entropy. Note that this branch's output resolution is half that of the original ground truth, so the logits must first be upsampled with bilinear interpolation before the loss is computed.
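    The upsample-then-compute-loss step looks like this (shapes are scaled-down stand-ins):

    ```python
    import torch
    import torch.nn.functional as F

    logits_aux = torch.randn(2, 2, 16, 16)   # auxiliary head output at half resolution
    gt = torch.randint(0, 2, (2, 32, 32))    # ground truth at full resolution

    # upsample the logits to the gt resolution before cross-entropy
    logits_up = F.interpolate(logits_aux, size=gt.shape[-2:],
                              mode="bilinear", align_corners=False)
    loss = F.cross_entropy(logits_up, gt)
    print(logits_up.shape)  # torch.Size([2, 2, 32, 32])
    ```

    Upsampling the logits (rather than downsampling the labels) keeps the label map exact; the total training loss is the decode-head loss plus a weighted auxiliary loss.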

    Full model structure

    EncoderDecoder(
      (backbone): UNet(
        (encoder): ModuleList(
          (0): Sequential(
            (0): BasicConvBlock(
              (convs): Sequential(
                (0): ConvModule(
                  (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
                (1): ConvModule(
                  (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
              )
            )
          )
          (1): Sequential(
            (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (1): BasicConvBlock(
              (convs): Sequential(
                (0): ConvModule(
                  (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
                (1): ConvModule(
                  (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
              )
            )
          )
          (2): Sequential(
            (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (1): BasicConvBlock(
              (convs): Sequential(
                (0): ConvModule(
                  (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
                (1): ConvModule(
                  (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
              )
            )
          )
          (3): Sequential(
            (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (1): BasicConvBlock(
              (convs): Sequential(
                (0): ConvModule(
                  (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
                (1): ConvModule(
                  (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
              )
            )
          )
          (4): Sequential(
            (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (1): BasicConvBlock(
              (convs): Sequential(
                (0): ConvModule(
                  (conv): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
                (1): ConvModule(
                  (conv): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
              )
            )
          )
        )
        (decoder): ModuleList(
          (0): UpConvBlock(
            (conv_block): BasicConvBlock(
              (convs): Sequential(
                (0): ConvModule(
                  (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
                (1): ConvModule(
                  (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
              )
            )
            (upsample): InterpConv(
              (interp_upsample): Sequential(
                (0): Upsample()
                (1): ConvModule(
                  (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
                  (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
              )
            )
          )
          (1): UpConvBlock(
            (conv_block): BasicConvBlock(
              (convs): Sequential(
                (0): ConvModule(
                  (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
                (1): ConvModule(
                  (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
              )
            )
            (upsample): InterpConv(
              (interp_upsample): Sequential(
                (0): Upsample()
                (1): ConvModule(
                  (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                  (bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
              )
            )
          )
          (2): UpConvBlock(
            (conv_block): BasicConvBlock(
              (convs): Sequential(
                (0): ConvModule(
                  (conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
                (1): ConvModule(
                  (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
              )
            )
            (upsample): InterpConv(
              (interp_upsample): Sequential(
                (0): Upsample()
                (1): ConvModule(
                  (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
                  (bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
              )
            )
          )
          (3): UpConvBlock(
            (conv_block): BasicConvBlock(
              (convs): Sequential(
                (0): ConvModule(
                  (conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
                (1): ConvModule(
                  (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
              )
            )
            (upsample): InterpConv(
              (interp_upsample): Sequential(
                (0): Upsample()
                (1): ConvModule(
                  (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
                  (bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activate): ReLU(inplace=True)
                )
              )
            )
          )
        )
      )
      init_cfg=[{'type': 'Kaiming', 'layer': 'Conv2d'}, {'type': 'Constant', 'val': 1, 'layer': ['_BatchNorm', 'GroupNorm']}]
      (decode_head): FCNHead(
        input_transform=None, ignore_index=255, align_corners=False
        (loss_decode): CrossEntropyLoss(avg_non_ignore=False)
        (conv_seg): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
        (dropout): Dropout2d(p=0.1, inplace=False)
        (convs): Sequential(
          (0): ConvModule(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (activate): ReLU(inplace=True)
          )
        )
      )
      init_cfg={'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}
      (auxiliary_head): FCNHead(
        input_transform=None, ignore_index=255, align_corners=False
        (loss_decode): CrossEntropyLoss(avg_non_ignore=False)
        (conv_seg): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
        (dropout): Dropout2d(p=0.1, inplace=False)
        (convs): Sequential(
          (0): ConvModule(
            (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (activate): ReLU(inplace=True)
          )
        )
      )
      init_cfg={'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}
    )

  • Original source: https://blog.csdn.net/ooooocj/article/details/125597442