• 论文阅读:Duplex Contextual Relation Network for Polyp Segmentation


    论文总体架构

    论文名称:用于息肉分割的双重上下文关系网络(ISBI2022)
    作者单位:北京邮电大学
    作者名称:尹子衿等
    代码地址: https://github.com/PRIS-CV/DCRNet/blob/master/lib/DCRNet.py

    摘要

    结肠镜检查中的息肉自动分割在结直肠癌(CRC)的早期诊断中起着关键作用。然而,息肉图像的多样性极大增加了准确分割的难度。现有的研究主要集中在学习单个图像中的上下文信息,但未能利用跨图像的息肉的同步视觉模式。本文从整个数据集的整体角度来探索上下文相关性,并提出了一个双工上下文关系网络(DCRNet)来捕获图像内和交叉图像之间上下文关系。基于上述两种相似性,每个输入区域的特征可以通过嵌入上下文区域来增强每个输入区域的特征。为了存储训练过程中先前图像嵌入的特征区域,设计了情景记忆并作为队列操作。我们在EndoScene、Kvasir-SEG和最近发布的大规模PICCOLO数据集上评估了所提出的方法。实验结果表明,我们提出的DCRNet在广泛使用的评价指标方面优于最先进的方法。

    贡献
    1、提出来嵌入上下文区域;
    2、设计了情景记忆并作为队列操作;
    3、提出了DCRNet;
    4、模型在多个结肠癌数据集上的表现良好。

    引言

    结肠癌的诊断和治疗中,对于息肉的区域分析是非常关键的步骤,切除息肉是预防和治疗早期结肠直肠癌的直接手段。结肠镜图像能够清晰地展示出整个患者结肠部分的信息,但是对于息肉的定位分割依然存在着以下困难:1、息肉多饰多样;2、息肉和结肠粘膜之间的边界过于模糊。如图所示:
    结肠癌图像示例
    从图像中我们能够观察到,有的比较明显,像 a b,肿起来的部分就是,而d就很夸张,c很不明显,不仔细看根本看不着。


    相关工作

    在现有的工作中,这里简介:
    1、多尺度提取特征的网络:ACSNet(MICCAI 2020),结合上下文信息和局部细节来应对息肉特征多样性的问题。
    PraNet使用多尺度的特征聚合的方法,根据局部特征提取轮廓图并通过上采样依次细化分割图。
    2、利用辅助信息来约束分割结果:SFANet(MICCAI 2019),利用区域边界约束,来选择特征聚合,提高分割精度。

    重点: 这些工作,额,好像都是在单个图像上找特征分割,这样的话是不是涉及到一个隐性的病灶相似度,然后选取对应的分割参数??如果是这样的话,一个模型所做到的工作就是在对于明显的病灶的分割的基础上,对于不同类型的息肉图像进行相应的隐形分类,简单图像简单分,复杂图像及不明显的图像就特殊方法,很有道理!
    所以本文就要提到一个机制,叫做情景记忆!

    理论证明:(Content-based medical image retrieval of ct images of
    liver lesions using manifold learning)已经证明了从其他图像中检索在放射学病变治疗过程中的意义。
    相关成果:在度量学习中已经有用到。
    所以,本文采用这种思想,从整个数据集的整体角度来探讨交叉图像和图像内的特征关联。

    工作总结:
    1、图像内上下文关系模块
    2、图像外上下文关系模块
    这两个模块也是即插即用的。

    模型结构

    先上图片

    DCRNet
    首先看到网络框架图,它由三部分组成,编码器、解码器、底部信息处理模块。
    编码解码器本文用到的是基于ResNet34的UNet,这里不再赘述。直接看重头戏!

    内部上下文关系

    class PAM_Module(Module):
        """ Position attention module"""
        #Ref from SAGAN
        def __init__(self, in_dim):
            super(PAM_Module, self).__init__()
            self.chanel_in = in_dim
            self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
            self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
            self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
            self.gamma = Parameter(torch.zeros(1))
    
            self.softmax = Softmax(dim=-1)
        def forward(self, x):
            """
                inputs :
                    x : input feature maps( B X C X H X W)
                returns :
                    out : attention value + input feature
                    attention: B X (HxW) X (HxW)
            """
            m_batchsize, C, height, width = x.size()
            proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
            proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
            energy = torch.bmm(proj_query, proj_key)
            attention = self.softmax(energy)
            proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)
    
            out = torch.bmm(proj_value, attention.permute(0, 2, 1))
            out = out.view(m_batchsize, C, height, width)
    
            out = self.gamma*out + x
            return out
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32

    这一段代码,作者在里面写的备注还是非常详细的,这个东东的作用就是建立当前图像中所有像素点之间的关系,然后将这种关系与输入相乘,从而得到加权的效果!当然,残差结构一直是保留项目,嗯,就是这样的。

    外部上下文关系(这个平生还是第一次见,值得重点观察)

    class DCRNet(ResNet34Unet):
        def __init__(self,
                     bank_size=20,
                     num_classes=1,
                     num_channels=3,
                     is_deconv=False,
                     decoder_kernel_size=3,
                     pretrained=True,
                     feat_channels=512
                     ):
            super().__init__(num_classes=1,
                     num_channels=3,
                     is_deconv=False,
                     decoder_kernel_size=3,
                     pretrained=True)
            
            self.bank_size = bank_size
            self.register_buffer("bank_ptr", torch.zeros(1, dtype=torch.long))  # memory bank pointer
            self.register_buffer("bank", torch.zeros(self.bank_size, feat_channels, num_classes))  # memory bank
            self.bank_full = False
            
            # =====Attentive Cross Image Interaction==== #
            self.feat_channels = feat_channels
            self.L = nn.Conv2d(feat_channels, num_classes, 1)
            self.X = conv2d(feat_channels, 512, 3)
            self.phi = conv1d(512, 256)
            self.psi = conv1d(512, 256)
            self.delta = conv1d(512, 256)
            self.rho = conv1d(256, 512)
            self.g = conv2d(512 + 512, 512, 1)
            # =========Dual Attention========== #
            self.sa_head = PAM_Module(feat_channels)
            #=========Attention Fusion=========#
            self.fusion = nn.Conv2d(feat_channels, feat_channels, 1)
        #==Initiate the pointer of bank buffer==#
        def init(self):
            self.bank_ptr[0] = 0
            self.bank_full = False
            
        @torch.no_grad() #这句很重要!!!!
        def update_bank(self, x):
            ptr = int(self.bank_ptr)
            batch_size = x.shape[0]
            vacancy = self.bank_size - ptr
            if batch_size >= vacancy:
                self.bank_full = True
            pos = min(batch_size, vacancy)
            self.bank[ptr:ptr+pos] = x[0:pos].clone()
            # update pointer
            ptr = (ptr + pos) % self.bank_size
            self.bank_ptr[0] = ptr
            
        def down(self, x):
            e1 = self.encoder1(x)
            e2 = self.encoder2(e1)
            e3 = self.encoder3(e2)
            e4 = self.encoder4(e3)        
            return e4, e3, e2, e1
        
        def up(self, feat, e3, e2, e1, x):
            center = self.center(feat)
            d4 = self.decoder4(torch.cat([center, e3], 1))
            d3 = self.decoder3(torch.cat([d4, e2], 1))
            d2 = self.decoder2(torch.cat([d3, e1], 1))
            d1 = self.decoder1(torch.cat([d2, x], 1))
     
            f1 = self.finalconv1(d1)
            f2 = self.finalconv2(d2)
            f3 = self.finalconv3(d3)
            f4 = self.finalconv4(d4)
                    
            f4 = F.interpolate(f4, scale_factor=8, mode='bilinear', align_corners=True)
            f3 = F.interpolate(f3, scale_factor=4, mode='bilinear', align_corners=True)
            f2 = F.interpolate(f2, scale_factor=2, mode='bilinear', align_corners=True)
            
            return f4, f3, f2, f1
       
        def region_representation(self, input):
            X = self.X(input)
            L = self.L(input)
            aux_out = L
            batch, n_class, height, width = L.shape
            l_flat = L.view(batch, n_class, -1)
            # M = B * N * HW
            M = torch.softmax(l_flat, -1)
            channel = X.shape[1]
            # X_flat = B * C * HW
            X_flat = X.view(batch, channel, -1)
            # f_k = B * C * N
            f_k = (M @ X_flat.transpose(1, 2)).transpose(1, 2)
            return aux_out, f_k, X_flat, X
        
        def attentive_interaction(self, bank, X_flat, X):
            batch, n_class, height, width = X.shape
            # query = S * C
            query = self.phi(bank).squeeze(dim=2)
            # key: = B * C * HW
            key = self.psi(X_flat)
            # logit = HW * S * B (cross image relation)
            logit = torch.matmul(query, key).transpose(0,2)
            # attn = HW * S * B
            attn = torch.softmax(logit, 2) ##softmax维度要正确
            
            # delta = S * C
            delta = self.delta(bank).squeeze(dim=2)
            # attn_sum = B * C * HW
            attn_sum = torch.matmul(attn.transpose(1,2), delta).transpose(1,2)
            # x_obj = B * C * H * W
            X_obj = self.rho(attn_sum).view(batch, -1, height, width)
    
            concat = torch.cat([X, X_obj], 1)
            out = self.g(concat)
            return out
                
        def forward(self, x, flag='train'):
            batch_size = x.shape[0]
            #=== Stem ===#
            x = self.firstconv(x)
            x = self.firstbn(x)
            x = self.firstrelu(x)
            x_ = self.firstmaxpool(x)
     
            #=== Encoder ===#
            e4, e3, e2, e1  = self.down(x_)        
            #=== Attentive Cross Image Interaction ===#
            aux_out, patch, feats_flat, feats = self.region_representation(e4)
            if flag == 'train':
                self.update_bank(patch)
                ptr = int(self.bank_ptr)
                if self.bank_full == True:
                    feature_aug = self.attentive_interaction(self.bank, feats_flat, feats)
                else:
                    feature_aug = self.attentive_interaction(self.bank[0:ptr], feats_flat, feats)
            elif flag == 'test':
                feature_aug = self.attentive_interaction(patch, feats_flat, feats)
            #=== Dual Attention ===#
            sa_feat = self.sa_head(e4)
            #=== Fusion ===#
            feats = sa_feat + feature_aug
            #=== Decoder ===#
            f4, f3, f2, f1 = self.up(feats, e3, e2, e1, x)
            aux_out = F.interpolate(aux_out, scale_factor=32, mode='bilinear', align_corners=True)
            return aux_out, f4, f3, f2, f1
    
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145

    实验分析

    实验部分主要包含以下几个方面:

    数据集名称图像数量trainvalidtest
    EndoScene912548182182
    Kvasir-SEG1000600200200
    PICCOLO34332203897333

    设备学习率epochesbatchsizememory size
    NVIDIA RTX 2080Ti1e-4150420(Kvasir) / 40(E & P)

    可视化结果
    表格对比
    从可视化和表格数据上,我们能够看出本文模型的有效性!

    DRC推理时间,大小和效果

    对于这两个经典模型,有着不错的提高,说明了本模型的设计和内外上下文推理体系的合理性。

    讨论

    本文最大的亮点应该是外部memory 的设定,对于整个模型的体系架构,我们应当学习到这种内部隐性的分类思想和理念,所谓的外部上下文关系模块的机理也是如此!

    厚着脸皮,要个点赞收藏,谢谢支持!!!

  • 相关阅读:
    【机器学习】特征工程之特征选择
    MySQl有哪些索引(种类)?索引特点?为什么要使用索引?
    spark读取文件夹数据
    单片机——使用P3口流水点亮8位LED
    代码随想录算法训练营Day6 | 242.有效的字母异位词 ●349. 两个数组的交集 ● 202. 快乐数● 1. 两数之和
    代码随想录算法训练营第一天 | 704. 二分查找 | 27. 移除元素
    【英语:语法基础】B5.核心语法-句子成分和五种简单句
    AJAX介绍
    艾美捷1,2-二硬脂酰-sn-甘油-3-PC(DSPC)化学性质
    centos通过docker安装rabbitMq和延迟队列说明
  • 原文地址:https://blog.csdn.net/TTLoveYuYu/article/details/125425949