• UNet 网络做图像分割DRIVE数据集


    目录

    1. 介绍

    2. 搭建 UNet 网络

    3. dataset 数据加载

    4. train 训练网络

    5. predict 分割图像

    6. show

    7. 完整代码


    1. 介绍

    完整项目下载地址:UNet 网络做图像分割DRIVE数据集

    项目的目录如下所示

    1. DRIVE 存放的是数据集
    2. predict 是待分割的图像
    3. result 里面放分割predict 的结果
    4. dataset 是处理数据的文件、model存放unet网络、predict是预测、train是网络的训练、UNet.pth 是训练好的权重文件

    之前做了一个图像分割的例子,里面大部分的代码和本篇的内容重合,所以每个脚本的代码只会做简单的介绍。具体的可以参考之前的内容,这里给出链接:

    model :  UNet - unet网络

    dataset :UNet - 数据加载 Dataset

    train : UNet - 训练数据train

    predict : UNet - 预测数据predict(多个图像的分割)

    DRIVE ( Digital Retinal Images for Vessel Extraction ):用于血管提取的数字视网膜图像

    训练样本:灰度图像

     对应的标签:二值图像

    因为这个分割项目完成几周了,最近才整理。所以,原数据集 DRIVE 可能是彩色图像 + mask 掩膜(具体的记不清了)

    • 这里没有使用 mask 
    • 如果是彩色图像的话,在生成unet网络的时候,传入的channel设置成3就行了。或者想用灰度图像的形式,要么用opencv转一下,可以看见灰度化的效果类似于展示的那样;要么在预处理的里面转成灰度图片 transform.Grayscale()

    2. 搭建 UNet 网络

    和之前unet网络不同的是,这里通过填充size,可以保证任意图像维度的输入

    之前的代码需要经过4此下采样,每次维度扩展,size减半,所以需要保证输入图像的大小是 2的4次方

    具体这块怎么实现我也看不懂,经过测试,可以实现任意输入的size

    3. dataset 数据加载

    数据加载的时候,将图像的预处理也放到了这里

    这里训练的图像要 ToTensor ,归一化+改变通道顺序+转为tensor等等。同时,为了加快训练,对图像正规化,因为训练的图像是灰度图,所以只需要单通道的均值和标准差


    然后是 数据加载 的初始化

    这里的imgs里面的内容是,传入路径root下的图像路径,这里是:

    ['01.png', '02.png', '03.png', '04.png', '05.png', '06.png', '07.png', '08.png', '09.png']

    self.imgs 是将root 路径和root 里面每个图像的路径 拼接在一块的路径,这里是:

    ['./DRIVE/test/image\\01.png', './DRIVE/test/image\\02.png', './DRIVE/test/image\\03.png', './DRIVE/test/image\\04.png', './DRIVE/test/image\\05.png', './DRIVE/test/image\\06.png', './DRIVE/test/image\\07.png', './DRIVE/test/image\\08.png', './DRIVE/test/image\\09.png']

    如图:


     初始化路径和预处理后,需要对图像进行处理

    这里训练的样本和对应的二值图像的label文件名要保证一样,否则需要做别的处理。例如,这里只需要将训练样本的图像路径里面的image 替换(replace)成label 就能找到对应的分割图像

    然后读取图像,预处理之后,在进行返回即可。

    这里为了防止label不是严格的二值图像,在归一化(灰度值 / 255)后,将中间的灰度值也映射为前景像素点

    4. train 训练网络

    训练网络的代码基本上没有改变,这里简单介绍

    判断网络运行的设备,将网络to到device上

    加载训练集+测试集

    这里传入的是训练的样本,因为Data_loader 会将样本的路径替换成 label找到对应分割的标签图像

    因为内存不足,所以这里将batch size 设置成 1

    然后定义优化器+损失函数,并且保存网络的训练权重文件

    有关BCEWithLogitsLoss可以参考这个:聊聊关于图像分割的损失函数 - BCEWithLogitsLoss

    训练的时候,需要网络在train模式下,然后就是正确的前向传播预测+反向梯度下降的内容

    最后是计算正确率,需要将网络放到eval模式下

    这里将网络的预测转为二值图像,然后计算准确率的方式是预测的二值图像和label进行逐个像素点的比对,最后比上整幅图像的空间分辨率,即图像的大小。

    test_label 的通道顺序是:batch、channel、height、width

    5. predict 分割图像

    这里的预处理要和处理样本的预处理一致

    加载网络+读取网络参数

    预测的时候,需要扩展维度。保存图像的时候,需要将batch和channel减去

    然后将预测的结果转为二值图像就可以了

    6. show

    训练了20个epoch,结果显示如下

    这里来预测的图像在test数据集里面,predict里面的图像为:

    UNet 分割的结果:

    真实的label为:

    分割了大部分的信息,但是仍有细节没有分割出来

    图像的size 是 565*584 的,大概预测的准确率是 0.96 左右

    也就是说 还有 565*584*0.04 = 13198 ,这些损失的像素点就是缺少的细节

    7. 完整代码

    model部分:

    1. import torch.nn as nn
    2. import torch
    3. import torch.nn.functional as F
    4. # 搭建unet 网络
    5. class DoubleConv(nn.Module): # 连续两次卷积
    6. def __init__(self,in_channels,out_channels):
    7. super(DoubleConv,self).__init__()
    8. self.double_conv = nn.Sequential(
    9. nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1,bias=False),
    10. nn.BatchNorm2d(out_channels), # 用 BN 代替 Dropout
    11. nn.ReLU(inplace=True),
    12. nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1,bias=False),
    13. nn.BatchNorm2d(out_channels),
    14. nn.ReLU(inplace=True)
    15. )
    16. def forward(self,x):
    17. x = self.double_conv(x)
    18. return x
    19. class Down(nn.Module): # 下采样
    20. def __init__(self,in_channels,out_channels):
    21. super(Down, self).__init__()
    22. self.downsampling = nn.Sequential(
    23. nn.MaxPool2d(kernel_size=2,stride=2),
    24. DoubleConv(in_channels,out_channels)
    25. )
    26. def forward(self,x):
    27. x = self.downsampling(x)
    28. return x
    29. class Up(nn.Module): # 上采样
    30. def __init__(self, in_channels, out_channels):
    31. super(Up,self).__init__()
    32. self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) # 转置卷积
    33. self.conv = DoubleConv(in_channels, out_channels)
    34. def forward(self, x1, x2):
    35. x1 = self.upsampling(x1)
    36. diffY = torch.tensor([x2.size()[2] - x1.size()[2]]) # 确保任意size的图像输入
    37. diffX = torch.tensor([x2.size()[3] - x1.size()[3]])
    38. x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
    39. diffY // 2, diffY - diffY // 2])
    40. x = torch.cat([x2, x1], dim=1) # 从channel 通道拼接
    41. x = self.conv(x)
    42. return x
    43. class OutConv(nn.Module): # 最后一个网络的输出
    44. def __init__(self, in_channels, num_classes):
    45. super(OutConv, self).__init__()
    46. self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)
    47. def forward(self, x):
    48. return self.conv(x)
    49. class UNet(nn.Module): # unet 网络
    50. def __init__(self, in_channels = 1, num_classes = 1):
    51. super(UNet, self).__init__()
    52. self.in_channels = in_channels
    53. self.num_classes = num_classes
    54. self.in_conv = DoubleConv(in_channels, 64)
    55. self.down1 = Down(64, 128)
    56. self.down2 = Down(128, 256)
    57. self.down3 = Down(256, 512)
    58. self.down4 = Down(512, 1024)
    59. self.up1 = Up(1024, 512)
    60. self.up2 = Up(512, 256)
    61. self.up3 = Up(256, 128)
    62. self.up4 = Up(128, 64)
    63. self.out_conv = OutConv(64, num_classes)
    64. def forward(self, x):
    65. x1 = self.in_conv(x)
    66. x2 = self.down1(x1)
    67. x3 = self.down2(x2)
    68. x4 = self.down3(x3)
    69. x5 = self.down4(x4)
    70. x = self.up1(x5, x4)
    71. x = self.up2(x, x3)
    72. x = self.up3(x, x2)
    73. x = self.up4(x, x1)
    74. x = self.out_conv(x)
    75. return x

    dataset 数据处理部分:

    1. import os
    2. from torch.utils.data import Dataset
    3. from PIL import Image
    4. from torchvision import transforms
    5. data_transform = {
    6. "train": transforms.Compose([transforms.ToTensor(),
    7. transforms.Normalize((0.5, ), (0.5, ))]),
    8. "test": transforms.Compose([transforms.ToTensor()])
    9. }
    10. # 数据处理文件
    11. class Data_Loader(Dataset): # 加载数据
    12. def __init__(self, root, transforms_train=data_transform['train'],transforms_test=data_transform['test']): # 初始化
    13. imgs = os.listdir(root) # 读取图像的路径
    14. self.imgs = [os.path.join(root,img) for img in imgs] # 取出路径下所有的图片
    15. self.transforms_train = transforms_train # 预处理
    16. self.transforms_test = transforms_test
    17. def __getitem__(self, index): # 获取数据、预处理等等
    18. image_path = self.imgs[index] # 根据index读取图片
    19. label_path = image_path.replace('image', 'label') # 根据image_path生成label_path
    20. image = Image.open(image_path) # 读取图片和对应的label图
    21. label = Image.open(label_path)
    22. image = self.transforms_train(image) # 样本预处理
    23. label = self.transforms_test(label) # label 预处理
    24. label[label > 0] = 1
    25. return image, label
    26. def __len__(self): # 返回样本的数量
    27. return len(self.imgs)

    train 网络训练部分:

    1. from model import UNet
    2. from dataset import Data_Loader
    3. from torch import optim
    4. import torch.nn as nn
    5. import torch
    6. # 网络训练模块
    7. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # GPU or CPU
    8. print(device)
    9. net = UNet(in_channels=1, num_classes=1) # 加载网络
    10. net.to(device) # 将网络加载到device上
    11. # 加载训练集
    12. trainset = Data_Loader("./DRIVE/train/image")
    13. train_loader = torch.utils.data.DataLoader(dataset=trainset,batch_size=1,shuffle=True)
    14. len = len(trainset) # 样本总数为 31
    15. # 加载测试集
    16. testset = Data_Loader("./DRIVE/test/image")
    17. test_loader = torch.utils.data.DataLoader(dataset=testset,batch_size=1)
    18. # 加载优化器和损失函数
    19. optimizer = optim.RMSprop(net.parameters(), lr=0.00001,weight_decay=1e-8, momentum=0.9) # 定义优化器
    20. criterion = nn.BCEWithLogitsLoss() # 定义损失函数
    21. # 保存网络参数
    22. save_path = './UNet.pth' # 网络参数的保存路径
    23. best_acc = 0.0 # 保存最好的准确率
    24. # 训练
    25. for epoch in range(20):
    26. net.train() # 训练模式
    27. running_loss = 0.0
    28. for image,label in train_loader:
    29. optimizer.zero_grad() # 梯度清零
    30. pred = net(image.to(device)) # 前向传播
    31. loss = criterion(pred, label.to(device)) # 计算损失
    32. loss.backward() # 反向传播
    33. optimizer.step() # 梯度下降
    34. running_loss += loss.item() # 计算损失和
    35. net.eval() # 测试模式
    36. acc = 0.0 # 正确率
    37. total = 0
    38. with torch.no_grad():
    39. for test_image, test_label in test_loader:
    40. outputs = net(test_image.to(device)) # 前向传播
    41. outputs[outputs >= 0] = 1 # 将预测图片转为二值图片
    42. outputs[outputs < 0] = 0
    43. # 计算预测图片与真实图片像素点一致的精度:acc = 相同的 / 总个数
    44. acc += (outputs == test_label.to(device)).sum().item() / (test_label.size(2) * test_label.size(3))
    45. total += test_label.size(0)
    46. accurate = acc / total # 计算整个test上面的正确率
    47. print('[epoch %d] train_loss: %.3f test_accuracy: %.3f %%' %
    48. (epoch + 1, running_loss/len, accurate*100))
    49. if accurate > best_acc: # 保留最好的精度
    50. best_acc = accurate
    51. torch.save(net.state_dict(), save_path) # 保存网络参数

    predict 预测部分:

    1. import numpy as np
    2. import torch
    3. import cv2
    4. from model import UNet
    5. from torchvision import transforms
    6. from PIL import Image
    7. transform = transforms.Compose([
    8. transforms.ToTensor(),
    9. transforms.Normalize((0.5,),(0.5))
    10. ])
    11. # 加载模型
    12. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    13. net = UNet(in_channels=1, num_classes=1)
    14. net.load_state_dict(torch.load('UNet.pth', map_location=device))
    15. net.to(device)
    16. # 测试模式
    17. net.eval()
    18. with torch.no_grad():
    19. img = Image.open('./predict/img.png') # 读取预测的图片
    20. img = transform(img) # 预处理
    21. img = torch.unsqueeze(img,dim = 0) # 增加batch维度
    22. pred = net(img.to(device)) # 网络预测
    23. pred = torch.squeeze(pred) # 将(batch、channel)维度去掉
    24. pred = np.array(pred.data.cpu()) # 保存图片需要转为cpu处理
    25. pred[pred >=0 ] =255 # 转为二值图片
    26. pred[pred < 0 ] =0
    27. pred = np.uint8(pred) # 转为图片的形式
    28. cv2.imwrite('./result/res.png', pred) # 保存图片

  • 相关阅读:
    Linux socket编程(5):三次握手和四次挥手分析和SIGPIPE信号的处理
    【K 均值聚类】02/5:简介
    设计模式之策略模式
    flex:1详解,以及flex:1和flex:auto的区别
    android studio 自带模拟器进行 Root 及 Xposed安装
    Android学习笔记 65. 数据绑定基础知识
    golang Context应用举例
    Docker 的基本概念和优势
    大数据:HDFS的Shell常用命令操作
    芯动联科冲刺科创板:年营收1.7亿 北方电子院与中城创投是股东
  • 原文地址:https://blog.csdn.net/qq_44886601/article/details/128188184