论文名称:用于息肉分割的双重上下文关系网络(ISBI2022)
作者单位:北京邮电大学
作者名称:尹子衿等
代码地址: https://github.com/PRIS-CV/DCRNet/blob/master/lib/DCRNet.py
结肠镜检查中的息肉自动分割在结直肠癌(CRC)的早期诊断中起着关键作用。然而,息肉图像的多样性极大增加了准确分割的难度。现有的研究主要集中在学习单个图像中的上下文信息,但未能利用跨图像的息肉的同步视觉模式。本文从整个数据集的整体角度来探索上下文相关性,并提出了一个双工上下文关系网络(DCRNet)来捕获图像内和交叉图像之间的上下文关系。基于上述两种相似性,每个输入区域的特征可以通过嵌入上下文区域来增强每个输入区域的特征。为了存储训练过程中先前图像嵌入的特征区域,设计了情景记忆并作为队列操作。我们在EndoScene、Kvasir-SEG和最近发布的大规模PICCOLO数据集上评估了所提出的方法。实验结果表明,我们提出的DCRNet在广泛使用的评价指标方面优于最先进的方法。
贡献:
1、提出来嵌入上下文区域;
2、设计了情景记忆并作为队列操作;
3、提出了DCRNet;
4、模型在多个结肠癌数据集上的表现良好。
结肠癌的诊断和治疗中,对于息肉的区域分析是非常关键的步骤,切除息肉是预防和治疗早期结肠直肠癌的直接手段。结肠镜图像能够清晰地展示出整个患者结肠部分的信息,但是对于息肉的定位分割依然存在着以下困难:1、息肉多饰多样;2、息肉和结肠粘膜之间的边界过于模糊。如图所示:
从图像中我们能够观察到,有的比较明显,像 a b,肿起来的部分就是,而d就很夸张,c很不明显,不仔细看根本看不着。
在现有的工作中,这里简介:
1、多尺度提取特征的网络:ACSNet(MICCAI 2020),结合上下文信息和局部细节来应对息肉特征多样性的问题。
PraNet使用多尺度的特征聚合的方法,根据局部特征提取轮廓图并通过上采样依次细化分割图。
2、利用辅助信息来约束分割结果:SFANet(MICCAI 2019),利用区域边界约束,来选择特征聚合,提高分割精度。
重点: 这些工作,额,好像都是在单个图像上找特征分割,这样的话是不是涉及到一个隐性的病灶相似度,然后选取对应的分割参数??如果是这样的话,一个模型所做到的工作就是在对于明显的病灶的分割的基础上,对于不同类型的息肉图像进行相应的隐形分类,简单图像简单分,复杂图像及不明显的图像就特殊方法,很有道理!
所以本文就要提到一个机制,叫做情景记忆!
理论证明:(Content-based medical image retrieval of ct images of
liver lesions using manifold learning)已经证明了从其他图像中检索在放射学病变治疗过程中的意义。
相关成果:在度量学习中已经有用到。
所以,本文采用这种思想,从整个数据集的整体角度来探讨交叉图像和图像内的特征关联。
首先看到网络框架图,它由三部分组成,编码器、解码器、底部信息处理模块。
编码解码器本文用到的是基于ResNet34的UNet,这里不再赘述。直接看重头戏!
class PAM_Module(Module):
""" Position attention module"""
#Ref from SAGAN
def __init__(self, in_dim):
super(PAM_Module, self).__init__()
self.chanel_in = in_dim
self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.gamma = Parameter(torch.zeros(1))
self.softmax = Softmax(dim=-1)
def forward(self, x):
"""
inputs :
x : input feature maps( B X C X H X W)
returns :
out : attention value + input feature
attention: B X (HxW) X (HxW)
"""
m_batchsize, C, height, width = x.size()
proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
energy = torch.bmm(proj_query, proj_key)
attention = self.softmax(energy)
proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
out = out.view(m_batchsize, C, height, width)
out = self.gamma*out + x
return out
这一段代码,作者在里面写的备注还是非常详细的,这个东东的作用就是建立当前图像中所有像素点之间的关系,然后将这种关系与输入相乘,从而得到加权的效果!当然,残差结构一直是保留项目,嗯,就是这样的。
class DCRNet(ResNet34Unet):
def __init__(self,
bank_size=20,
num_classes=1,
num_channels=3,
is_deconv=False,
decoder_kernel_size=3,
pretrained=True,
feat_channels=512
):
super().__init__(num_classes=1,
num_channels=3,
is_deconv=False,
decoder_kernel_size=3,
pretrained=True)
self.bank_size = bank_size
self.register_buffer("bank_ptr", torch.zeros(1, dtype=torch.long)) # memory bank pointer
self.register_buffer("bank", torch.zeros(self.bank_size, feat_channels, num_classes)) # memory bank
self.bank_full = False
# =====Attentive Cross Image Interaction==== #
self.feat_channels = feat_channels
self.L = nn.Conv2d(feat_channels, num_classes, 1)
self.X = conv2d(feat_channels, 512, 3)
self.phi = conv1d(512, 256)
self.psi = conv1d(512, 256)
self.delta = conv1d(512, 256)
self.rho = conv1d(256, 512)
self.g = conv2d(512 + 512, 512, 1)
# =========Dual Attention========== #
self.sa_head = PAM_Module(feat_channels)
#=========Attention Fusion=========#
self.fusion = nn.Conv2d(feat_channels, feat_channels, 1)
#==Initiate the pointer of bank buffer==#
def init(self):
self.bank_ptr[0] = 0
self.bank_full = False
@torch.no_grad() #这句很重要!!!!
def update_bank(self, x):
ptr = int(self.bank_ptr)
batch_size = x.shape[0]
vacancy = self.bank_size - ptr
if batch_size >= vacancy:
self.bank_full = True
pos = min(batch_size, vacancy)
self.bank[ptr:ptr+pos] = x[0:pos].clone()
# update pointer
ptr = (ptr + pos) % self.bank_size
self.bank_ptr[0] = ptr
def down(self, x):
e1 = self.encoder1(x)
e2 = self.encoder2(e1)
e3 = self.encoder3(e2)
e4 = self.encoder4(e3)
return e4, e3, e2, e1
def up(self, feat, e3, e2, e1, x):
center = self.center(feat)
d4 = self.decoder4(torch.cat([center, e3], 1))
d3 = self.decoder3(torch.cat([d4, e2], 1))
d2 = self.decoder2(torch.cat([d3, e1], 1))
d1 = self.decoder1(torch.cat([d2, x], 1))
f1 = self.finalconv1(d1)
f2 = self.finalconv2(d2)
f3 = self.finalconv3(d3)
f4 = self.finalconv4(d4)
f4 = F.interpolate(f4, scale_factor=8, mode='bilinear', align_corners=True)
f3 = F.interpolate(f3, scale_factor=4, mode='bilinear', align_corners=True)
f2 = F.interpolate(f2, scale_factor=2, mode='bilinear', align_corners=True)
return f4, f3, f2, f1
def region_representation(self, input):
X = self.X(input)
L = self.L(input)
aux_out = L
batch, n_class, height, width = L.shape
l_flat = L.view(batch, n_class, -1)
# M = B * N * HW
M = torch.softmax(l_flat, -1)
channel = X.shape[1]
# X_flat = B * C * HW
X_flat = X.view(batch, channel, -1)
# f_k = B * C * N
f_k = (M @ X_flat.transpose(1, 2)).transpose(1, 2)
return aux_out, f_k, X_flat, X
def attentive_interaction(self, bank, X_flat, X):
batch, n_class, height, width = X.shape
# query = S * C
query = self.phi(bank).squeeze(dim=2)
# key: = B * C * HW
key = self.psi(X_flat)
# logit = HW * S * B (cross image relation)
logit = torch.matmul(query, key).transpose(0,2)
# attn = HW * S * B
attn = torch.softmax(logit, 2) ##softmax维度要正确
# delta = S * C
delta = self.delta(bank).squeeze(dim=2)
# attn_sum = B * C * HW
attn_sum = torch.matmul(attn.transpose(1,2), delta).transpose(1,2)
# x_obj = B * C * H * W
X_obj = self.rho(attn_sum).view(batch, -1, height, width)
concat = torch.cat([X, X_obj], 1)
out = self.g(concat)
return out
def forward(self, x, flag='train'):
batch_size = x.shape[0]
#=== Stem ===#
x = self.firstconv(x)
x = self.firstbn(x)
x = self.firstrelu(x)
x_ = self.firstmaxpool(x)
#=== Encoder ===#
e4, e3, e2, e1 = self.down(x_)
#=== Attentive Cross Image Interaction ===#
aux_out, patch, feats_flat, feats = self.region_representation(e4)
if flag == 'train':
self.update_bank(patch)
ptr = int(self.bank_ptr)
if self.bank_full == True:
feature_aug = self.attentive_interaction(self.bank, feats_flat, feats)
else:
feature_aug = self.attentive_interaction(self.bank[0:ptr], feats_flat, feats)
elif flag == 'test':
feature_aug = self.attentive_interaction(patch, feats_flat, feats)
#=== Dual Attention ===#
sa_feat = self.sa_head(e4)
#=== Fusion ===#
feats = sa_feat + feature_aug
#=== Decoder ===#
f4, f3, f2, f1 = self.up(feats, e3, e2, e1, x)
aux_out = F.interpolate(aux_out, scale_factor=32, mode='bilinear', align_corners=True)
return aux_out, f4, f3, f2, f1
实验部分主要包含以下几个方面:
数据集名称 | 图像数量 | train | valid | test |
---|---|---|---|---|
EndoScene | 912 | 548 | 182 | 182 |
Kvasir-SEG | 1000 | 600 | 200 | 200 |
PICCOLO | 3433 | 2203 | 897 | 333 |
设备 | 学习率 | epoches | batchsize | memory size |
---|---|---|---|---|
NVIDIA RTX 2080Ti | 1e-4 | 150 | 4 | 20(Kvasir) / 40(E & P) |
从可视化和表格数据上,我们能够看出本文模型的有效性!
对于这两个经典模型,有着不错的提高,说明了本模型的设计和内外上下文推理体系的合理性。
本文最大的亮点应该是外部memory 的设定,对于整个模型的体系架构,我们应当学习到这种内部隐性的分类思想和理念,所谓的外部上下文关系模块的机理也是如此!