• 【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割1(综述篇)


    在上一个关于3D 目标的任务,是基于普通CNN网络的3D分类任务。在这个任务中,分类数据采用的是CT结节的LIDC-IDRI数据集,其中对结节的良恶性、毛刺、分叶征等等特征进行了各自的等级分类。感兴趣的可以直接点击下方的链接,直达学习:

    1. 【3D图像分类】基于Pytorch的3D立体图像分类1(基础篇)
    2. 【3D图像分类】基于Pytorch的3D立体图像分类2(数据增强篇)

    在开始本次关于3D 目标的分割任务前呢,我还是建议先去看看上述较为简单的分类任务,毕竟大多数是相似的,有很高的借鉴意义。

    一、导言

    准备一个训练,需要下面这些内容组成:

    1. 准备数据
    2. 准备网络
    3. 搭建训练主模型
      • train one epoch
      • valid one epoch
      • 存储模型
      • 存储指标
    4. loss 函数
    5. dice coeff 评估指标
    6. optimizer优化方式

    其中,在本项目中:

    1. 网络采用vnet 3d模型
    2. 数据采用patch裁剪大小
    3. loss函数未dice loss
    4. 评价指标是dice coeff
    5. optimizer优化方式是SGD

    二、搭建主结构

    训练的主体结构(骨架),总数包括几个部分:

    1. config:可调参数定义,包括数据路径、图像大小、类别数量、学习率、batch size等等;
    2. main:主函数,包括:
      • 构建模型
      • 构建数据
      • 优化器
      • 学习率变化方式
      • 损失函数
      • 评估指标
      • 训练batch循环
      • 验证batch循环
    3. 后处理:包括模型参数存储,指标走势绘图等等。

    上面这些个内容,基本上是囊括了深度学习模型训练的整体结构了,后面的工作就是对每一部分进行补充。就犹如已经有了骨架,后续就是补充肉身了。

    后面给出的这个pytorch骨架案例,也是后面再构建训练任务,一个可以参考的依据,可收藏。

    2.1、导入库和配置参数

    import os
    import matplotlib.pyplot as plt
    import torch.utils.data
    import torch.optim as optim
    
    from datasets.datasets import myDataset
    
    os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"  # 使用gpu0
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 没gpu就用cpu
    print(DEVICE)
    
    ############################################################
    # Configuration
    ############################################################
    class Configuration(object):
        train_path = r"./database/sk_output/train"
        valid_path = r"./database/sk_output/valid"
        model_path = r'./checkpoints'
    
        Crop_Size = (48, 96, 96)
        num_outs = 2
    
        Batch_Train = 32
        Batch_Test = 16
        Max_epoch = 220
        Num_Workers = 8
    
        Dice_Best = 0
        LR = 0.0003
    
        momentum = 0.99
        weight_decay = 1e-8
    
        def display(self):
            """Display Configuration values."""
            print("\nConfigurations:")
            print("")
            for a in dir(self):
                if not a.startswith("__") and not callable(getattr(self, a)):
                    print("{:30} {}".format(a, getattr(self, a)))
            print("\n")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41

    2.2、构建main主函数

    def main():
        Config = Configuration()
        Config.display()
    
        train_loader, valid_loader = get_Dataloader(Config)
    
        model = get_model(Config).to(DEVICE)
    
        # ---- OPTIMIZER ----
        optimizer = optim.SGD(model.parameters(), lr=Config.LR, momentum=Config.momentum, weight_decay=Config.weight_decay)
    
        train_loss_list = []  # 用来记录训练损失
        valid_loss_list = []  # 用来记录验证损失
        valid_dice_list = []
    
        epoch_list = []
        for epoch in range(1, Config.Max_epoch + 1):
            epoch_list.append(epoch)
            train_loss = train_model(model, DEVICE, train_loader, optimizer, epoch)  # 训练
            valid_loss, valid_dice = valid_model(model, DEVICE, valid_loader, epoch)  # 验证
    
            train_loss_list.append(train_loss)
            valid_loss_list.append(valid_loss)
            valid_dice_list.append(valid_dice)
    
            draw_plot(epoch_list, valid_dice_list, 'valid_dice')
            draw_plot(epoch_list, valid_loss_list, 'valid_loss')
            draw_plot(epoch_list, train_loss_list, 'train_loss')
    
            if valid_dice > Config.Dice_Best:
                path_ckpt = os.path.join(Config.model_path, 'best_model.pth')
                save_model(path_ckpt, model)
                Config.Dice_Best = valid_dice
            else:
                path_ckpt = os.path.join(Config.model_path, 'last_model.pth')
                save_model(path_ckpt, model)
    
        print('best val Dice is ', Config.Dice_Best)
    
    if __name__ == '__main__':
        main()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41

    2.3、构建获取模型和数据的函数

    def get_model(config):
        from models.vnet3d import VNet3D
        model = VNet3D(num_outs=config.num_outs, channels=16)
    
        model = model.to(DEVICE)  # 模型部署到gpu或cpu里
        model = torch.nn.DataParallel(model).to(DEVICE)
        return model
    
    def get_Dataloader(config):
    	# get train data
        dataset_train = myDataset(config.train_path, config.Crop_Size, isTrain=True)
        print(len(dataset_train))
        train_loader = torch.utils.data.DataLoader(dataset_train,
                                                   batch_size=config.Batch_Train, shuffle=True,
                                                   num_workers=config.Num_Workers, drop_last=False)
    	# get valid data
        dataset_valid = myDataset(config.valid_path, config.Crop_Size, isTrain=False)
        valid_loader = torch.utils.data.DataLoader(dataset_valid,
                                                   batch_size=config.Batch_Test, shuffle=False,
                                                   num_workers=config.Num_Workers, drop_last=False)
    
        return train_loader, valid_loader
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    2.4、构建训练循环和验证循环

    def train_model(model, device, train_loader, optimizer, epoch):
        config = Configuration()
        model.train()
    
        for batch_index, (data, target) in enumerate(train_loader):  # 取batch索引,(data,target),也就是图和标签
            data, target = data.to(device), target.to(device)
    
            output = model(data)
            loss = Loss(output, target)
    
            optimizer.zero_grad()  # 梯度归零
            loss.backward()  # 反向传播
            optimizer.step()  # 优化器走一步
    
        return losses.avg  # 返回平均损失,损失列表
    
    
    def valid_model(model, device, test_loader, epoch):
        config = Configuration()
        model.eval()
    
        with torch.no_grad():  # 不进行 梯度计算(反向传播)
            for batch_index, (data, target) in enumerate(test_loader):  # 枚举batch索引,(图,标签)
                data, target = data.to(device), target.to(device)
    
                output = model(data)
                loss = Loss(output, target)  # 计算损失
    
        return losses.avg, multi_dices.avg
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    2.5、后处理

    保存模型的参数,和绘制训练过程中train loss、valid loss,以及valid dice走势图,如下:

    def draw_plot(x_list, y_list, title_name):
        plt.plot(x_list, y_list, label=title_name)
    
        plt.xlabel('x', fontsize=15)
        plt.ylabel('y', fontsize=15)
        plt.title(title_name, fontsize=15)
        plt.savefig('./logs/cure.png')
    
    
    def save_model(path, model):
        if isinstance(model, torch.nn.DataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        torch.save(state_dict, path)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    至此,每一个模块都有了对应的归宿,后面就是如何将缺漏的地方,补全过程了。反倒是这部分的代码相对较少,两大需要单独验证的数据和模型是大头,其他就好办了。

    三、总结

    本文是关于PytorchVNet 3D 图像分割的第一篇,也就是一个综述篇,主要是对这个项目的任务目的,以及其中的一个流程进行了梳理。

    上述的骨干代码还不能够作为训练使用,还需要补充进去骨肉,才能够适应不同的任务,这一块的内容将会在后面的几个篇章中,一一陈述。

    如果你也在做类似的事情,欢迎点赞、收藏,mark住。对于这部分的内容可以一起交流,欢迎多多评论。

  • 相关阅读:
    【Java杂谈】#1 【MCA JAVA后端架构师】
    运行python进行指定内容的文件名查找
    servlet映射路径匹配解析
    超好用的A-level/IGCSE学习网站
    C++基础知识(五)--- 智能指针类&字符串类
    DOM系列之 click 延时解决方案
    《苍穹外卖》知识梳理6-缓存商品,购物车功能
    探索Java生态系统的其他技术与工具
    Item-Based Recommendations with Hadoop
    解读最早的草图-图像翻译工作SketchyGAN
  • 原文地址:https://blog.csdn.net/wsLJQian/article/details/133966815