• DDANet: Dual Decoder Attention Network for Automatic Polyp Segmentation


    双decoder用于息肉分割。文章的创新点在与使用了双分支的decoder,单encoder的结构。decoder的第二个分支会产生注意力map,在代码中体现为输出通道为1。这个和之前看的confidence map很像。
    看一下文章的结构图:
    在这里插入图片描述
    在decoder中,第二个分支生成注意力图,其实shared encoder的跳连接上下两个是一样的,在代码中可以看到,稍后分析。
    encoder,decoder的构成:
    在这里插入图片描述
    他这里使用的encoder不是原始的resnet,但是使用了resnet的思想,且在两个3x3卷积之后,加入了通道注意力,这个在ESANet中RGB和Depth融合方法中也有用到。
    在decoder使用的4倍转置卷积,和encoder的特征进行concat,和RedNet的跳连接结构,和上采样结构都很像。
    实验:
    医学图像的数据集,不太了解。
    ------------------------------------------------------分割线-------------------------------------------------------------------------------------------------------------------------
    代码:

    
    import torch
    import torch.nn as nn
    import torchvision.models as models
    
    class SELayer(nn.Module):
        def __init__(self, channel, reduction=16):
            super().__init__()
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Sequential(
                nn.Linear(channel, int(channel / reduction), bias=False),
                nn.ReLU(inplace=True),
                nn.Linear(int(channel / reduction), channel, bias=False),
                nn.Sigmoid()
            )
    
        def forward(self, x):
            b, c, _, _ = x.size()
            y = self.avg_pool(x).view(b, c)
            y = self.fc(y).view(b, c, 1, 1)
            return x * y.expand_as(x)
    
    class ResidualBlock(nn.Module):
        def __init__(self, in_c, out_c):
            super(ResidualBlock, self).__init__()
    
            self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
            self.bn1 = nn.BatchNorm2d(out_c)
    
            self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
            self.bn2 = nn.BatchNorm2d(out_c)
    
            self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1, padding=0)
            self.bn3 = nn.BatchNorm2d(out_c)
            self.se = SELayer(out_c)
    
            self.relu = nn.ReLU(inplace=True)
    
        def forward(self, x):
            x1 = self.conv1(x)
            x1 = self.bn1(x1)
            x1 = self.relu(x1)
    
            x2 = self.conv2(x1)
            x2 = self.bn2(x2)
    
            x3 = self.conv3(x)
            x3 = self.bn3(x3)
            x3 = self.se(x3)
    
            x4 = x2 + x3
            x4 = self.relu(x4)
    
            return x4
    
    class EncoderBlock(nn.Module):
        def __init__(self, in_c, out_c):
            super(EncoderBlock, self).__init__()
    
            self.r1 = ResidualBlock(in_c, out_c)
            self.r2 = ResidualBlock(out_c, out_c)
            self.pool = nn.MaxPool2d(2, stride=2)
    
        def forward(self, x):
            x = self.r1(x)
            x = self.r2(x)
            p = self.pool(x)
    
            return x, p
    
    class DecoderBlock(nn.Module):
        def __init__(self, in_c, out_c):
            super(DecoderBlock, self).__init__()
    
            self.upsample = nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1)
            self.r1 = ResidualBlock(in_c+out_c, out_c)
            self.r2 = ResidualBlock(out_c, out_c)
    
        def forward(self, x, s):
            x = self.upsample(x)
            x = torch.cat([x, s], axis=1)
            x = self.r1(x)
            x = self.r2(x)
    
            return x
    
    
    class CompNet(nn.Module):
        def __init__(self):
            super(CompNet, self).__init__()
    
            """ Shared Encoder """
            self.e1 = EncoderBlock(3, 32)
            self.e2 = EncoderBlock(32, 64)
            self.e3 = EncoderBlock(64, 128)
            self.e4 = EncoderBlock(128, 256)
    
            """ Decoder: Segmentation """
            self.s1 = DecoderBlock(256, 128)
            self.s2 = DecoderBlock(128, 64)
            self.s3 = DecoderBlock(64, 32)
            self.s4 = DecoderBlock(32, 16)
    
            """ Decoder: Autoencoder """
            self.a1 = DecoderBlock(256, 128)
            self.a2 = DecoderBlock(128, 64)
            self.a3 = DecoderBlock(64, 32)
            self.a4 = DecoderBlock(32, 16)
    
            """ Autoencoder attention map """
            self.m1 = nn.Sequential(
                nn.Conv2d(128, 1, kernel_size=1, padding=0),
                nn.Sigmoid()
            )
            self.m2 = nn.Sequential(
                nn.Conv2d(64, 1, kernel_size=1, padding=0),
                nn.Sigmoid()
            )
            self.m3 = nn.Sequential(
                nn.Conv2d(32, 1, kernel_size=1, padding=0),
                nn.Sigmoid()
            )
            self.m4 = nn.Sequential(
                nn.Conv2d(16, 1, kernel_size=1, padding=0),
                nn.Sigmoid()
            )
    
            """ Output """
            self.output1 = nn.Conv2d(16, 1, kernel_size=1, padding=0)
            self.output2 = nn.Conv2d(16, 1, kernel_size=1, padding=0)
    
        def forward(self, x):
            """ Encoder """
            x1, p1 = self.e1(x)
            x2, p2 = self.e2(p1)
            x3, p3 = self.e3(p2)
            x4, p4 = self.e4(p3)
    
            """ Decoder 1 """
            s1 = self.s1(p4, x4)
            a1 = self.a1(p4, x4)
            m1 = self.m1(a1)
            x5 = s1 * m1
    
            """ Decoder 2 """
            s2 = self.s2(x5, x3)
            a2 = self.a2(a1, x3)
            m2 = self.m2(a2)
            x6 = s2 * m2
    
            """ Decoder 3 """
            s3 = self.s3(x6, x2)
            a3 = self.a3(a2, x2)
            m3 = self.m3(a3)
            x7 = s3 * m3
    
            """ Decoder 4 """
            s4 = self.s4(x7, x1)
            a4 = self.a4(a3, x1)
            m4 = self.m4(a4)
            x8 = s4 * m4
    
            """ Output """
            out1 = self.output1(x8)
            out2 = self.output2(a4)
    
            return out1, out2
    
    if __name__ == "__main__":
        x = torch.rand((1, 3, 512, 512))
    
        model = CompNet()
        y1, y2 = model.forward(x)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174

    我们直接看forward函数:
    1:首先就是输入的x经过四个encoder block
    在这里插入图片描述在这里插入图片描述
    在这里插入图片描述
    每个block中包含的residual block:

    
    class ResidualBlock(nn.Module):
        def __init__(self, in_c, out_c):
            super(ResidualBlock, self).__init__()
    
            self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
            self.bn1 = nn.BatchNorm2d(out_c)
    
            self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
            self.bn2 = nn.BatchNorm2d(out_c)
    
            self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1, padding=0)
            self.bn3 = nn.BatchNorm2d(out_c)
            self.se = SELayer(out_c)
    
            self.relu = nn.ReLU(inplace=True)
    
        def forward(self, x):
            x1 = self.conv1(x)
            x1 = self.bn1(x1)
            x1 = self.relu(x1)
    
            x2 = self.conv2(x1)
            x2 = self.bn2(x2)
    
            x3 = self.conv3(x)
            x3 = self.bn3(x3)
            x3 = self.se(x3)
    
            x4 = x2 + x3
            x4 = self.relu(x4)
    
            return x4
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33

    以第一个为例:输入通道为3,输出通道为32.
    经过两个3x3卷积,然后一个1x1卷积的残差连接,注意,这里经过1x1卷积之后才经过SElayer。和图中画的有些不同。
    SELayer:和之前看的通道注意力一样,首先经过平均池化,然后经过两个线性层,最后与原始的x相乘得到最终的结果。
    在这里插入图片描述
    每个encoder block包含两个残差块,一个2x2最大池化。注意这里返回的是x,用来跳连接的。 x, p对应于x1, p1 。同理x2为p1。encoder结束后,两个输出分别经过两个decoder分支。
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在decoder block中首先通过转置卷积进行上采样四倍,然后和跳连接相concat,再经过两个残差块。而另一条分支的处理和这一条一样的。
    在这里插入图片描述
    不同的是第二条分支产生的结果是一个注意力图。通过一个卷积生成通道为1的attention map,这里医学分割图的最终结果就是单通道,如果是其他的数据集即时多通道,即这里其实相当于之前说的置信度图。
    在这里插入图片描述
    生成的置信度图与第一条分支产生的结果相乘,这样执行四次。
    在这里插入图片描述
    最终的输出经过两个通道为1的卷积,即最终的分割图。
    在这里插入图片描述
    整个网络框架可以简化为:
    在这里插入图片描述

  • 相关阅读:
    【目录】Java程序设计课程学习导航(更新中)
    C++PrimerPlus(第6版)中文版:Chapter16.5.1函数对象_函数符概念
    某60区块链安全之不安全的随机数实战一
    Fasttext解读(1)
    越细粒度的锁越好吗?产生死锁怎么办?
    详解项目中使用dotPeek调试源码
    配电房环境智能监控系统:守护电力设施,保障安全运行
    Spring源码解析(十):spring整合mybatis源码
    克隆的虚拟机,查不到IP号
    类与对象(二)----对象详解
  • 原文地址:https://blog.csdn.net/qq_43733107/article/details/127979946