• 深度学习论文精读[5]:Attention UNet


    5ef3b355d32ea799bcaecce10ec7ae56.jpeg

    以CNN为基础的编解码结构在图像分割上展现出了卓越的效果,尤其是医学图像的自动分割上。但一些研究认为以往的FCN和UNet等分割网络存在计算资源和模型参数的过度和重复使用,例如相似的低层次特征被级联内的所有网络重复提取。针对这类普遍性的问题,相关研究提出了给UNet添加注意力门控(Attention Gates, AGs)的方法,形成一个新的图像分割网络结构:Attention UNet。提出Attention UNet的论文为Attention U-Net: Learning Where to Look for the Pancreas,发表在2018年CVPR上。注意力机制原先是在自然语言处理领域被提出并逐渐得到广泛应用的一种新型结构,旨在模仿人的注意力机制,有针对性的聚焦数据中的突出特征,能够使得模型更加高效。

    Attention UNet的网络结构如下图所示,需要注意的是,论文中给出的3D版本的卷积网络。其中编码器部分跟UNet编码器基本一致,主要的变化在于解码器部分。其结构简要描述如下:编码器部分,输入图像经过两组3*3*3的3D卷积和ReLU激活,然后再进行最大池化下采样,经过3组这样的卷积-池化块之后,网络进入到解码器部分。编码器最后一层的特征图除了直接进行上采样外,还与来自编码器的特征图进行注意力门控计算,然后再与上采样的特征图进行合并,经过三次这样的上采样块之后即可得到最终的分割输出图。相比于普通UNet的解码器,Attention UNet会将解码器中的特征与编码器连接过来的特征进行注意力门控处理,然后再与上采样进行拼接。经过注意力门控处理后得到的特征图会包含不同空间位置的重要性信息,使得模型能够重点关注某些目标区域。

    0596d83aa5853e6fd1630a594ea16ee9.png

    我们将Attention UNet的注意力门控单独拿出来进行分析,看AGs是如何让模型能够聚焦到目标区域的。如图中上图所示,将Attention UNet网络中的一个上采样块单独拿出来,其中x_l为来自同层编码器的输出特征图,g表示由解码器部分用于上采样的特征图,这里同时也作为注意力门控的门控信号参数与x_l的注意力计算,而x^hat_l即为经过注意力门控计算后的特征图,此时x^hat_l是包含了空间位置重要性信息的特征图,再将其与下一层上采样后的特征图进行合并才得到该上采样块最终的输出。

    1906bbe4c34f8631178dc47f74bfeef5.png

    ea6b6a3dc51a2cb50f9412a3fa1272e7.png

    将x_l和g_i计算得到的注意力系数再次与x_l相乘即可得到x^hat_l,这种经过与注意力系数相乘后的特征图会让图像中不相关的区域值变小,目标区域的值相对会变大,提升网络预测速度同时,也会提高图像的分割精度。论文中的各项实验结果也表明,经过注意力门控加成后后UNet,效果均要优于原始的UNet。下述代码给出了Attention UNet的一个2D参考实现,并且下采样次数由论文中的3次改为了4次。

    1. ### 定义Attention UNet类
    2. class Att_UNet(nn.Module):
    3. def __init__(self,img_ch=3,output_ch=1):
    4. super(Att_UNet, self).__init__()
    5. self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
    6. self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
    7. self.Conv2 = conv_block(ch_in=64, ch_out=128)
    8. self.Conv3 = conv_block(ch_in=128, ch_out=256)
    9. self.Conv4 = conv_block(ch_in=256, ch_out=512)
    10. self.Conv5 = conv_block(ch_in=512, ch_out=1024)
    11. self.Up5 = up_conv(ch_in=1024, ch_out=512)
    12. self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
    13. self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
    14. self.Up4 = up_conv(ch_in=512, ch_out=256)
    15. self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
    16. self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
    17. self.Up3 = up_conv(ch_in=256, ch_out=128)
    18. self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
    19. self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
    20. self.Up2 = up_conv(ch_in=128, ch_out=64)
    21. self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
    22. self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
    23. self.Conv_1x1 =
    24. nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)
    25. ### 定义前向传播流程
    26. def forward(self,x):
    27. # 编码器部分
    28. x1 = self.Conv1(x)
    29. x2 = self.Maxpool(x1)
    30. x2 = self.Conv2(x2)
    31. x3 = self.Maxpool(x2)
    32. x3 = self.Conv3(x3)
    33. x4 = self.Maxpool(x3)
    34. x4 = self.Conv4(x4)
    35. x5 = self.Maxpool(x4)
    36. x5 = self.Conv5(x5)
    37. # 解码器+连接部分
    38. d5 = self.Up5(x5)
    39. x4 = self.Att5(g=d5,x=x4)
    40. d5 = torch.cat((x4,d5),dim=1)
    41. d5 = self.Up_conv5(d5)
    42. d4 = self.Up4(d5)
    43. x3 = self.Att4(g=d4,x=x3)
    44. d4 = torch.cat((x3,d4),dim=1)
    45. d4 = self.Up_conv4(d4)
    46. d3 = self.Up3(d4)
    47. x2 = self.Att3(g=d3,x=x2)
    48. d3 = torch.cat((x2,d3),dim=1)
    49. d3 = self.Up_conv3(d3)
    50. d2 = self.Up2(d3)
    51. x1 = self.Att2(g=d2,x=x1)
    52. d2 = torch.cat((x1,d2),dim=1)
    53. d2 = self.Up_conv2(d2)
    54. d1 = self.Conv_1x1(d2)
    55. return d1
    56. ### 定义Attention门控块
    57. class Attention_block(nn.Module):
    58. def __init__(self, F_g, F_l, F_int):
    59. super(Attention_block, self).__init__()
    60. # 注意力门控向量
    61. self.W_g = nn.Sequential(
    62. nn.Conv2d(F_g, F_int,
    63. kernel_size=1, stride=1, padding=0, bias=True),
    64. nn.BatchNorm2d(F_int)
    65. )
    66. # 同层编码器特征图向量
    67. self.W_x = nn.Sequential(
    68. nn.Conv2d(F_l, F_int,
    69. kernel_size=1,stride=1,padding=0,bias=True),
    70. nn.BatchNorm2d(F_int)
    71. )
    72. # ReLU激活函数
    73. self.relu = nn.ReLU(inplace=True)
    74. # 卷积+BN+sigmoid激活函数
    75. self.psi = nn.Sequential(
    76. nn.Conv2d(F_int, 1,
    77. kernel_size=1, stride=1, padding=0, bias=True),
    78. nn.BatchNorm2d(1),
    79. nn.Sigmoid()
    80. )
    81. ### Attention门控的前向计算流程
    82. def forward(self,g,x):
    83. g1 = self.W_g(g)
    84. x1 = self.W_x(x)
    85. psi = self.relu(g1+x1)
    86. psi = self.psi(psi)
    87. return x*psi

    63f380ba252dc94161e7de3dd5955aa5.png

    总结来说,Attention UNet提出了在原始UNet基础添加注意力门控单元,注意力得分能够使得图像分割时聚焦到目标区域,该结构作为一个通用结构可以添加到任何任务类型的神经网络结构中,在语义分割网络中对前景目标区域的像素更具有敏感度。Attention UNet壮大了UNet家族网络,此后基于其的改进版本也层出不穷。

    往期精彩:

     深度学习论文精读[1]:FCN全卷积网络

     深度学习论文精读[2]:UNet网络

     深度学习论文精读[3]:SegNet

     深度学习论文精读[4]:RefineNet

     讲解视频来了!机器学习 公式推导与代码实现开录!

     完结!《机器学习 公式推导与代码实现》全书1-26章PPT下载

  • 相关阅读:
    10 次面试 9 次被刷?吃透这 500 道大厂 Java 高频面试题后,怒斩 offer
    图解LeetCode——652. 寻找重复的子树(难度:中等)
    面向机器理解的多视角上下文匹配
    【无标题】
    AI智能识别技术如何助力校园智慧食堂建设、保障餐饮卫生安全?
    WORD中的表格内容回车行距过大无法调整行距
    红米ac2100路由器刷入openwrt教程
    STL中map介绍
    什么是 MyBatis?与 Hibernate 的区别
    gRPC入门学习之旅(六)
  • 原文地址:https://blog.csdn.net/weixin_37737254/article/details/125863392