• 【论文阅读】Directional Connectivity-based Segmentation of Medical Images


    论文:Directional Connectivity-based Segmentation of Medical Images
    代码:https://github.com/zyun-y/dconnnet

    摘要

    出发点:生物标志分割中的解剖学一致性对许多医学图像分析任务至关重要。
    之前工作的问题:以往的连通性工作忽略了潜在空间中丰富的信道方向的信息。
    证明:有效地将方向子空间从共享潜在空间中解耦可以显著增强基于连通性网络中的特征表示
    提出:一种用于分割的定向连通性建模方案,该方案解耦、跟踪和利用跨网络的方向信息。

    介绍

    在这里插入图片描述
    介绍了基于像素分类和基于连通性的模型之间潜在的空间差异。前者仅突出分类特征,eg:边界。后者包含方向信息,例如:边界像素之间的水平连接。
    在这里插入图片描述
    将两组潜在特征(范畴性和方向性)在DconnNet的潜在空间中的流向用T - SNE进行可视化。它们先被解缠,然后在一个投影的共享流形中有效地融合,基于聚类的结果进行颜色的渲染。
    在这里插入图片描述
    这个是普通的分割掩码变成连通性掩码的示意图。每一个原来一个像素的位置包含了周围8个像素的mask值。一个一个对应即可。感觉好像这个图中间那个错了,中间像素的C1是positive。

    方法

    由于不同像素类别和方向之间的连通性,基于连通性的网络的潜在空间中存在两组特征:类别信息和方向信息。每一组特征在隐空间中形成其特定的子空间。两个子空间是高度耦合的。我们证明了方向空间的有效解缠和有效利用可以增强连通性模型中的整体特征表示。
    Pretrained ResNet:提取特征。
    SDE:特征信息和方向信息解耦。
    IFD:特征信息与方向信息融合。
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    效果

    在这里插入图片描述
    在这里插入图片描述
    利用T - SNE对DconnNet在SDE模块前后的隐通道嵌入进行可视化。( b )中的颜色表示无监督聚类结果。当应用于SDE时,通道嵌入自然地分组为几个不同的部分。

    结论

    其核心思想是将方向子空间从共享的潜在空间中解耦出来,并利用提取的方向特征来增强整体的数据表示。

    1. 通过与其他先进方法的统计比较,显示了DconnNet的整体性能更好。
    2. 通过在一个拓扑敏感的数据集上定性和定量地将DconnNet与其他方法进行比较,展示了其保留拓扑结构的能力。
    3. 通过对DconnNet的隐空间进行可视化,揭示了方向子空间的解纠缠过程

    跑通代码

    数据集方面

    #作者的数据读取中是读取的3通道图像和二值的mask图像,我们写一个数据集读取的函数能让他输出读取的3通道图像和二值的mask图像变成的tensor就可以。自己可以加一点数据增强。
    #root_path 是数据集的地址,fold_json存储了10折的图片的名称,fold_num是选取哪一个折作为验证集,就是十折交叉验证的内容,image_size最后resize的图片大小,mode是训练还是验证,augmentation_prob就是数据增强的概率	
    	#大家按照自己的数据集,写一个数据集的函数,之后跑别人代码的时候直接用就可以
    import os
    import random
    from random import shuffle
    import numpy as np
    import torch
    from torch.utils import data
    from torchvision import transforms as T
    from torchvision.transforms import functional as F
    from PIL import Image
    import numpy as np
    import json
    
    from .GetDataset_CHASE import connectivity_matrix
    class ImageFolder(data.Dataset):
    	def __init__(self, image_root,label_root,json_path, fold=1, image_size=400, mode='train', augmentation_prob=0.4):
    		"""Initializes image paths and preprocessing module."""
    		self.root = image_root
    		with open(json_path, 'r') as load_f:
    			self.fold_data = json.load(load_f)
    		self.data_list=[]
    		if mode == 'train':
    			for i in range(1, 11):
    				if i != fold:
    					self.data_list += self.fold_data['Fold ' + str(i)]
    		elif mode == 'val':
    			self.data_list = self.fold_data['Fold ' + str(fold)]
    		else:
    			raise ValueError("数据类型只有train和val")
    		self.image_size = image_size
    		self.label_root = label_root
    		self.mode = mode
    		# self.RotationDegree = [0,90,180,270]
    		self.augmentation_prob = augmentation_prob
    		print("image count in {} path :{}".format(self.mode,len(self.data_list)))
    
    	def __getitem__(self, index):
    		"""Reads an image from a file and preprocesses it and returns."""
    		image_path = os.path.join(self.root,self.data_list[index])
    		GT_path = os.path.join(self.label_root ,self.data_list[index])
    
    		image = Image.open(image_path).convert('RGB')
    		GT = Image.open(GT_path).convert('1')
    
    
    
    		aspect_ratio = image.size[0]/image.size[1]#weight/height
    
    
    		Transform = []
    
    		ResizeRange = random.randint(500,525)
    		Transform.append(T.Resize((ResizeRange,int(ResizeRange*aspect_ratio))))
    		p_transform = random.random()
    
    		if (self.mode == 'train') and p_transform <= self.augmentation_prob:
    
    			RotationRange = random.randint(-10,10)
    			Transform.append(T.RandomRotation((RotationRange,RotationRange)))
    			CropRange = random.randint(500,525)
    			Transform.append(T.CenterCrop((CropRange,int(CropRange*aspect_ratio))))
    			Transform = T.Compose(Transform)
    			
    			image = Transform(image)
    			GT = Transform(GT)
    
    
    			ShiftRange_left = random.randint(0,20)
    			ShiftRange_upper = random.randint(0,20)
    			ShiftRange_right = image.size[0] - random.randint(0,20)
    			ShiftRange_lower = image.size[1] - random.randint(0,20)
    			image = image.crop(box=(ShiftRange_left,ShiftRange_upper,ShiftRange_right,ShiftRange_lower))
    			GT = GT.crop(box=(ShiftRange_left,ShiftRange_upper,ShiftRange_right,ShiftRange_lower))
    			#
    			# if random.random() < 0.5:
    			# 	image = F.hflip(image)
    			# 	GT = F.hflip(GT)
    			#
    			# if random.random() < 0.5:
    			# 	image = F.vflip(image)
    			# 	GT = F.vflip(GT)
    
    			Transform = T.ColorJitter(brightness=0.2,contrast=0.2,hue=0.02)
    
    			image = Transform(image)
    
    			Transform =[]
    
    
    		Transform.append(T.Resize([256,256]))
    		Transform.append(T.ToTensor())
    		Transform = T.Compose(Transform)
    		
    		image = Transform(image)
    		GT = Transform(GT)
    
    		mean = [0.1591, 0.1591, 0.1591]
    		std = [0.2593, 0.2593, 0.2593]
    		Norm_ = T.Normalize(mean, std)
    		image = Norm_(image)
    		# images = image
    		# image = torch.unsqueeze(image,0)
    		# images = torch.cat([images,image],dim=1)
    
    
    
    		return image, GT
    
    	def __len__(self):
    		"""Returns the total number of font files."""
    		return len(self.data_list)
    
    def get_loader(image_path,label_path,image_size, batch_size, json_path, fold =None,num_workers=2, mode='train',augmentation_prob=0.4):
    	"""Builds and returns Dataloader."""
    	
    	dataset = ImageFolder(image_root = image_path,label_root=label_path,json_path=json_path, fold = fold,image_size =image_size, mode=mode,augmentation_prob=augmentation_prob)
    	data_loader = data.DataLoader(dataset=dataset,
    								  batch_size=batch_size,
    								  shuffle=True,
    								  num_workers=num_workers)
    	return data_loader
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124

    main函数修改

    我直接在函数上写上自己数据集的地址和相关参数了由于时间原因,大家可以把它加入到args参数里面更规范。我是class=1的任务,所以一定把class修改了,源代码中是4.

    def main(args):
    
        ## K-fold cross validation ##
        for exp_id in range(args.folds):
        # 
            train_loader = get_loader(image_path='自己数据集中image的位置',
                                      label_path='自己数据集中mask的位置',
                                      json_path="json文件的位置",
                                      image_size=(256,256),
                                      batch_size=1,
                                      fold=1,
                                      num_workers=8,
                                      mode='train',
                                      augmentation_prob=0.4)
            val_loader = get_loader(image_path='自己数据集中image的位置',
                                    label_path='自己数据集中mask的位置',
                                    json_path="json文件的位置",
                                    image_size=(256,256),
                                    batch_size=1,
                                    fold=1,
                                    num_workers=8,
                                    mode='val',
                                    augmentation_prob=0.)
    
    
            print("Train batch number: %i" % len(train_loader))
            print("Test batch number: %i" % len(val_loader))
    
            #### Above: define how you get the data on your own dataset ######
            model = DconnNet(num_class=1).cuda()
    
            if args.pretrained:
                model.load_state_dict(torch.load(args.pretrained,map_location = torch.device('cpu')))
                model = model.cuda()
    
            solver = Solver(args)
    
            solver.train(model, train_loader, val_loader,exp_id+1, num_epochs=args.epochs)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39

    connect_loss.py

    我是一个类别的,运行一直报错,是connect_loss.py这个函数它最后有个conn = conn.squeeze()注释掉就可以运行了。

    def connectivity_matrix(multimask, class_num):
    
        ##### converting segmentation masks to connectivity masks ####
    
        [batch,_,rows, cols] = multimask.shape
        # batch = 1
        conn = torch.zeros([batch,class_num*8,rows, cols]).cuda()
        for i in range(class_num):
            mask = multimask[:,i,:,:]
            # print(mask.shape)
            up = torch.zeros([batch,rows, cols]).cuda()#move the orignal mask to up
            down = torch.zeros([batch,rows, cols]).cuda()
            left = torch.zeros([batch,rows, cols]).cuda()
            right = torch.zeros([batch,rows, cols]).cuda()
            up_left = torch.zeros([batch,rows, cols]).cuda()
            up_right = torch.zeros([batch,rows, cols]).cuda()
            down_left = torch.zeros([batch,rows, cols]).cuda()
            down_right = torch.zeros([batch,rows, cols]).cuda()
    
    
            up[:,:rows-1, :] = mask[:,1:rows,:]
            down[:,1:rows,:] = mask[:,0:rows-1,:]
            left[:,:,:cols-1] = mask[:,:,1:cols]
            right[:,:,1:cols] = mask[:,:,:cols-1]
            up_left[:,0:rows-1,0:cols-1] = mask[:,1:rows,1:cols]
            up_right[:,0:rows-1,1:cols] = mask[:,1:rows,0:cols-1]
            down_left[:,1:rows,0:cols-1] = mask[:,0:rows-1,1:cols]
            down_right[:,1:rows,1:cols] = mask[:,0:rows-1,0:cols-1]
    
            conn[:,(i*8)+0,:,:] = mask*down_right
            conn[:,(i*8)+1,:,:] = mask*down
            conn[:,(i*8)+2,:,:] = mask*down_left
            conn[:,(i*8)+3,:,:] = mask*right
            conn[:,(i*8)+4,:,:] = mask*left
            conn[:,(i*8)+5,:,:] = mask*up_right
            conn[:,(i*8)+6,:,:] = mask*up
            conn[:,(i*8)+7,:,:] = mask*up_left
    
        conn = conn.float()
        # conn = conn.squeeze()
        # print(conn.shape)
        return conn
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
  • 相关阅读:
    Android->layer-list画对号画叉号画箭头画进度条
    如何制作一个卡刷扩容补丁。用于扩容系统等分区 刷写第三方需要扩容才可以刷写的系统或者GSI GSI系统bug修复【二】
    力扣二分篇
    纸浆暴力反弹——复制去年走势,铁矿石认购2-4倍,双硅价差再度翘尾?2022.6.28
    索引数据结构详解
    Jekyll如何自定义摘要
    计算机毕业设计php_thinphp_vue的约课管理系统-课程预约(源码+系统+mysql数据库+Lw文档)
    网络编程01
    在Kibana中使用Discover来制作表格table
    Golang如何使用命令行-- flag库
  • 原文地址:https://blog.csdn.net/goodenough5/article/details/133142917