UNet Semantic Segmentation Network


    1. References

    [Alchemy] A Summary of the UNet Image Segmentation Model — animalslin's tech blog, 51CTO

    https://cuijiahua.com/blog/2019/11/dl-14.html

    PyTorch Deep Learning in Practice (2): UNet Semantic Segmentation Network — Tencent Cloud Developer Community

    2. Introduction to the UNet Network

    The UNet network is used for semantic segmentation.

    Semantic segmentation assigns a label to every pixel belonging to each target class in the image, so that different kinds of objects are distinguished in the image. It can be understood as a pixel-level classification task: each pixel is classified individually.

    Suppose there are five classes: Person, Purse, Plants/Grass, Sidewalk, Building/Structures. We then create a one-hot encoded target annotation, i.e. one output channel per class. Since there are 5 classes, the network also outputs 5 channels, as illustrated below:

    [Figure: one-hot encoded output, one channel per class]

    Because no pixel should be 1 in more than one channel (arguably), the prediction can be collapsed into a single segmentation map by taking the argmax over the depth (channel) dimension at each pixel; overlaying that map on the image then reveals each target.
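    As a quick illustration (the 5-channel output tensor below is synthetic), the argmax collapse looks like this:

    import torch

    # hypothetical network output: batch of 1, 5 class channels, 4x4 spatial size
    logits = torch.randn(1, 5, 4, 4)

    # argmax over the channel (depth) dimension yields one class index per pixel
    seg_map = torch.argmax(logits, dim=1)  # shape (1, 4, 4), values in 0..4
    print(seg_map.shape, seg_map.unique())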

    The UNet architecture is shown below (the actual implementation keeps the same idea but makes small adjustments):

    [Figure: UNet architecture from the original paper]

    3. Overall UNet Training Plan

    (1) Annotate semantics with labelme, which produces a json file for each image.

    (2) Write code that uses the points information in each json file to produce a mask image from the original image.

    (3) Feed 3-channel images into the UNet network and have it predict a 1-channel mask (assuming a single target class). Fit the model by computing a BCE loss between the predicted mask and the ground-truth mask, and report accuracy and dice score as monitoring metrics (a minimal sketch of the loss step follows).
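    A minimal sketch of the loss step in (3), assuming a 1-channel logit output and a binary ground-truth mask (the shapes mirror the hyperparameters used in train.py below):

    import torch
    import torch.nn as nn

    pred_logits = torch.randn(2, 1, 160, 240)  # raw network output, no sigmoid applied
    target_mask = torch.randint(0, 2, (2, 1, 160, 240)).float()  # binary ground truth

    # BCEWithLogitsLoss applies the sigmoid internally, so it takes raw logits
    loss = nn.BCEWithLogitsLoss()(pred_logits, target_mask)
    print(loss.item())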

    4. UNet Implementation Analysis

    (1) Polygon annotation with labelme

    After annotation is finished, a json file is generated in the directory containing the image.

    (2) Generating mask images from the json files

    File: json2mask.py

    import json
    import os

    import cv2
    import numpy as np
    from PIL import Image, ImageDraw

    CLASS_NAMES = ['dog', 'cat']

    def make_mask(image_dir, save_dir):
        # collect the labelme json files in the image directory
        json_files = [f for f in os.listdir(image_dir) if f.endswith('.json')]
        for js in json_files:
            json_data = json.load(open(os.path.join(image_dir, js), 'r'))
            shapes_ = json_data['shapes']
            # 'P' mode = 8-bit pixels, same size as the source photo
            mask = Image.new('P', Image.open(os.path.join(image_dir, js.replace('json', 'jpg'))).size)
            mask_draw = ImageDraw.Draw(mask)  # drawing handle bound to the mask image
            for shape_ in shapes_:
                label = shape_['label']
                points = tuple(tuple(p) for p in shape_['points'])
                # fill each polygon with its class index (background stays 0)
                mask_draw.polygon(points, fill=CLASS_NAMES.index(label) + 1)
            # NOTE: * 255 only makes sense for a single class; a class index of 2
            # would give 2 * 255, which overflows uint8
            mask = np.array(mask) * 255
            cv2.imshow('mask', mask)
            cv2.waitKey(0)
            cv2.imwrite(os.path.join(save_dir, js.replace('json', 'jpg')), mask)

    def vis_label(img):
        # print the set of distinct pixel values found in a mask image
        img = np.array(Image.open(img))
        print(set(img.reshape(-1).tolist()))

    if __name__ == '__main__':
        make_mask('D:\\ai_data\\cat\\val', 'D:\\ai_data\\cat\\val_mask')

    Notes:

    • mode='P' in Image.new produces an image with 8-bit pixels, which is well suited to mask generation.
    • A pixel value of 255 is white, i.e. the masked region in the mask image is white and the rest is black. In practice the saved mask contains values in roughly (249, 255) rather than exactly 255, so it needs an extra cleanup step before use; a sketch follows.
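    A minimal cleanup sketch ('mask.jpg' is a placeholder filename; the threshold of 200 mirrors what dataset.py below uses):

    import numpy as np
    from PIL import Image

    # JPEG compression leaves foreground values scattered in roughly (249, 255)
    # instead of exactly 255, so binarize with a generous threshold
    mask = np.array(Image.open('mask.jpg').convert('L'), dtype=np.float32)
    mask = (mask > 200.0).astype(np.float32)  # 1.0 foreground, 0.0 background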

    (3) Building the UNet network

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class DoubleConv(nn.Module):
        """(convolution => [BN] => ReLU) * 2"""
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.double_conv = nn.Sequential(
                # padding=0, as in the paper: each 3x3 conv shrinks H and W by 2
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )

        def forward(self, x):
            return self.double_conv(x)

    class Down(nn.Module):
        """Downscaling with maxpool then double conv"""
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.maxpool_conv = nn.Sequential(
                nn.MaxPool2d(2),
                DoubleConv(in_channels, out_channels)
            )

        def forward(self, x):
            return self.maxpool_conv(x)

    class Up(nn.Module):
        """Upscaling then double conv"""
        def __init__(self, in_channels, out_channels, bilinear=True):
            super().__init__()
            if bilinear:
                # caution: bilinear upsampling keeps the channel count, so the
                # concat below would not match DoubleConv(in_channels, ...);
                # this listing always runs with bilinear=False
                self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            else:
                # the transposed conv halves the channel count, e.g. 1024 -> 512
                self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

        def forward(self, x1, x2):
            x1 = self.up(x1)
            # input is NCHW; pad x1 so it matches the skip feature x2
            diffY = x2.size()[2] - x1.size()[2]
            diffX = x2.size()[3] - x1.size()[3]
            x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                            diffY // 2, diffY - diffY // 2])
            # concatenate along the channel dimension, e.g. 512 + 512 -> 1024
            x = torch.cat([x2, x1], dim=1)
            return self.conv(x)

    class OutConv(nn.Module):
        def __init__(self, in_channels, out_channels):
            super(OutConv, self).__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        def forward(self, x):
            return self.conv(x)

    class UNet(nn.Module):
        def __init__(self, n_channels, n_classes, bilinear=False):
            super(UNet, self).__init__()
            self.n_channels = n_channels
            self.n_classes = n_classes
            self.bilinear = bilinear
            self.inc = DoubleConv(n_channels, 64)
            self.down1 = Down(64, 128)
            self.down2 = Down(128, 256)
            self.down3 = Down(256, 512)
            self.down4 = Down(512, 1024)
            self.up1 = Up(1024, 512, bilinear)
            self.up2 = Up(512, 256, bilinear)
            self.up3 = Up(256, 128, bilinear)
            self.up4 = Up(128, 64, bilinear)
            self.outc = OutConv(64, n_classes)

        def forward(self, x):
            x1 = self.inc(x)
            x2 = self.down1(x1)
            x3 = self.down2(x2)
            x4 = self.down3(x3)
            x5 = self.down4(x4)
            x = self.up1(x5, x4)
            x = self.up2(x, x3)
            x = self.up3(x, x2)
            x = self.up4(x, x1)
            logits = self.outc(x)
            return logits

    if __name__ == '__main__':
        net = UNet(n_channels=3, n_classes=1)
        print(net)
        x = torch.randn([1, 3, 572, 572])
        out = net(x)
        print(out.shape)  # torch.Size([1, 1, 564, 564]): the valid (padding=0) convs shrink the output

    Notes:

    • This code follows the UNet paper. It is not used as-is in this project and needs slight modification.
    • It is a good companion for reading the UNet paper; for anything unclear, see: PyTorch Deep Learning in Practice (2): UNet Semantic Segmentation Network (Tencent Cloud Developer Community).
    • Downsampling is done with convolutions and max pooling.
    • Upsampling is done with transposed convolutions plus concatenation with the corresponding downsampling features.
    • Upsampling halves the channel dimension, e.g. 1024 -> 512; because the result is then concatenated with the matching downsampling feature (also 512 channels) along dim=1 (the channel dimension), the channel count returns to 1024.
    • In the paper, the input w, h differ from the output w', h', which is a problem when comparing against mask images, so we want the input and output sizes to match. Setting padding=1 keeps w, h unchanged through double_conv; only pooling halves them, and upsampling doubles them back (see the shape check after this list).
    • The code above is only for understanding UNet; it is not part of the final project code.
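    A quick shape check of the padding claim in the last bullets (random weights, arbitrary input size):

    import torch
    import torch.nn as nn

    x = torch.randn(1, 3, 160, 240)

    # padding=0 (paper version): each 3x3 conv trims 2 pixels per spatial dimension
    print(nn.Conv2d(3, 64, kernel_size=3, padding=0)(x).shape)  # [1, 64, 158, 238]

    # padding=1 (our variant): H and W are preserved
    print(nn.Conv2d(3, 64, kernel_size=3, padding=1)(x).shape)  # [1, 64, 160, 240]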

    (4) Main script train.py

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    from tqdm import tqdm

    from model import UNET
    # from unet_model_new import UNet
    from utils import (
        load_checkpoint,
        save_checkpoint,
        get_loaders,
        check_accuracy,
        save_predictions_as_imgs,
    )

    # Hyperparameters
    learning_rate = 1e-4
    device = 'cpu'
    batch_size = 1
    num_epochs = 30
    num_workers = 0
    image_height = 160
    image_width = 240
    pin_memory = False
    load_model = False
    train_img_dir = "D:\\ai_data\\cat\\train2"
    train_mask_dir = "D:\\ai_data\\cat\\train2_mask"
    val_img_dir = "D:\\ai_data\\cat\\val2"
    val_mask_dir = "D:\\ai_data\\cat\\val2_mask"

    def train_fn(loader, model, optimizer, loss_fn):
        for batch_idx, (data, targets) in enumerate(tqdm(loader)):
            data = data.to(device=device)
            # masks are (N, H, W); add the channel dim to match the (N, 1, H, W) output
            targets = targets.float().unsqueeze(1).to(device=device)
            predictions = model(data)
            loss = loss_fn(predictions, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()  # missing in the original listing; without it the weights never update

    def main():
        train_transform = A.Compose(
            [
                A.Resize(height=image_height, width=image_width),
                A.Rotate(limit=35, p=1.0),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.1),
                A.Normalize(
                    mean=[0.0, 0.0, 0.0],
                    std=[1.0, 1.0, 1.0],
                    max_pixel_value=255.0
                ),
                ToTensorV2(),
            ],
        )
        val_transform = A.Compose(
            [
                A.Resize(height=image_height, width=image_width),
                A.Normalize(
                    mean=[0.0, 0.0, 0.0],
                    std=[1.0, 1.0, 1.0],
                    max_pixel_value=255.0
                ),
                ToTensorV2(),
            ],
        )
        model = UNET(in_channels=3, out_channels=1).to(device)
        loss_fn = nn.BCEWithLogitsLoss()  # sigmoid + BCE in one numerically stable op
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        train_loader, val_loader = get_loaders(
            train_img_dir,
            train_mask_dir,
            val_img_dir,
            val_mask_dir,
            batch_size,
            train_transform,
            val_transform,
            num_workers,
            pin_memory
        )
        if load_model:
            load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)
            check_accuracy(-1, "val", val_loader, model, device=device)
        for epoch in range(num_epochs):
            train_fn(train_loader, model, optimizer, loss_fn)
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_checkpoint(checkpoint)
            check_accuracy(epoch, "train", train_loader, model, device=device)
            check_accuracy(epoch, "val", val_loader, model, device=device)
            save_predictions_as_imgs(val_loader, model, folder="saved_images/", device=device)

    if __name__ == "__main__":
        main()

    (5) Data loading: dataset.py

    import os

    import numpy as np
    from PIL import Image
    from torch.utils.data import Dataset

    class CarvanaDataset(Dataset):
        def __init__(self, image_dir, mask_dir, transform=None):
            self.image_dir = image_dir
            self.mask_dir = mask_dir
            self.transform = transform
            self.images = os.listdir(image_dir)

        def __len__(self):
            return len(self.images)

        def __getitem__(self, index):
            img_path = os.path.join(self.image_dir, self.images[index])
            # masks are expected to be named <image>_mask.jpg
            mask_path = os.path.join(self.mask_dir, self.images[index].replace(".jpg", "_mask.jpg"))
            image = np.array(Image.open(img_path).convert("RGB"))
            mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
            mask[mask > 200.0] = 1.0  # after grayscale conversion the mask is not uniformly 255
            if self.transform is not None:
                augmentations = self.transform(image=image, mask=mask)
                image = augmentations["image"]
                mask = augmentations["mask"]
            return image, mask
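    A hypothetical usage sketch (the paths are placeholders; with transform=None the samples come back as raw arrays, which DataLoader collates into tensors):

    from torch.utils.data import DataLoader

    ds = CarvanaDataset(
        image_dir='D:\\ai_data\\cat\\train2',
        mask_dir='D:\\ai_data\\cat\\train2_mask',
        transform=None,
    )
    loader = DataLoader(ds, batch_size=1, shuffle=True)
    image, mask = next(iter(loader))  # image: (1, H, W, 3) until ToTensorV2 reorders it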

    (6) Model: model.py

    import torch
    import torch.nn as nn
    import torch.nn.functional as F  # the original imported torch.functional, which has no pad
    import torchvision.transforms.functional as TF

    class DoubleConv(nn.Module):
        def __init__(self, in_channels, out_channels):
            super(DoubleConv, self).__init__()
            self.conv = nn.Sequential(
                # padding=1 keeps the conv2d output H/W unchanged
                nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        def forward(self, x):
            return self.conv(x)

    class UNET(nn.Module):
        def __init__(
            self, in_channels=3, out_channels=1, features=[64, 128, 256, 512],
        ):
            super(UNET, self).__init__()
            self.ups = nn.ModuleList()
            self.downs = nn.ModuleList()
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            # Down part of UNET
            for feature in features:
                self.downs.append(DoubleConv(in_channels, feature))
                in_channels = feature
            # Up part of UNET
            for feature in reversed(features):
                self.ups.append(
                    nn.ConvTranspose2d(
                        feature * 2, feature, kernel_size=2, stride=2,
                    )
                )
                self.ups.append(DoubleConv(feature * 2, feature))
            self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
            self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

        def forward(self, x):
            skip_connections = []
            for down in self.downs:
                x = down(x)
                skip_connections.append(x)
                x = self.pool(x)
            x = self.bottleneck(x)
            skip_connections = skip_connections[::-1]
            for idx in range(0, len(self.ups), 2):
                x = self.ups[idx](x)
                skip_connection = skip_connections[idx // 2]
                if x.shape != skip_connection.shape:
                    # with padding=1 this branch is only reached when H or W is not
                    # divisible by 16: pooling floors odd sizes, so the upsampled
                    # feature can be one pixel smaller than the skip feature
                    x = TF.resize(x, size=skip_connection.shape[2:])
                    # alternative, as in the paper-style version:
                    # diffY = skip_connection.size()[2] - x.size()[2]
                    # diffX = skip_connection.size()[3] - x.size()[3]
                    # x = F.pad(x, [diffX // 2, diffX - diffX // 2,
                    #               diffY // 2, diffY - diffY // 2])
                concat_skip = torch.cat((skip_connection, x), dim=1)
                x = self.ups[idx + 1](concat_skip)
            return self.final_conv(x)

    def test():
        x = torch.randn((3, 1, 572, 572))
        model = UNET(in_channels=1, out_channels=1)
        preds = model(x)
        assert preds.shape == x.shape

    if __name__ == "__main__":
        test()

    (7) Utilities: utils.py

    import torch
    import torchvision
    from torch.utils.data import DataLoader

    from dataset import CarvanaDataset

    def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
        print("=> Saving checkpoint")
        torch.save(state, filename)

    def load_checkpoint(checkpoint, model):
        print("=> Loading checkpoint")
        model.load_state_dict(checkpoint["state_dict"])

    def get_loaders(
        train_dir,
        train_maskdir,
        val_dir,
        val_maskdir,
        batch_size,
        train_transform,
        val_transform,
        num_workers=4,
        pin_memory=True,
    ):
        train_ds = CarvanaDataset(
            image_dir=train_dir,
            mask_dir=train_maskdir,
            transform=train_transform,
        )
        train_loader = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=True,
        )
        val_ds = CarvanaDataset(
            image_dir=val_dir,
            mask_dir=val_maskdir,
            transform=val_transform,
        )
        val_loader = DataLoader(
            val_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=False,
        )
        return train_loader, val_loader

    def check_accuracy(epoch, attr, loader, model, device="cuda"):
        num_correct = 0
        num_pixels = 0
        dice_score = 0
        model.eval()
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()
                num_correct += (preds == y).sum()
                num_pixels += torch.numel(preds)
                dice_score += (2 * (preds * y).sum()) / (
                    (preds + y).sum() + 1e-8
                )
        print(f"{attr}_{epoch+1}: Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}")
        print(f"{attr}_{epoch+1}: Dice score: {dice_score/len(loader)}")
        model.train()

    def save_predictions_as_imgs(
        loader, model, folder="saved_images/", device="cuda"
    ):
        model.eval()
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()
            torchvision.utils.save_image(
                preds, f"{folder}/pred_{idx}.png"
            )
            torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")
        model.train()

    (8) Notes on the dice score metric

    Reference: On dice, an evaluation metric for image segmentation — Pierce_KK's blog, CSDN

    The dice metric is also used in machine learning. Its expression is:

        Dice = 2|X ∩ Y| / (|X| + |Y|)

    This is identical to the F1 metric in machine learning.

    Precision:

        Precision = TP / (TP + FP)

    Recall:

        Recall = TP / (TP + FN)

    F1 is the harmonic mean of precision and recall, i.e.:

        F1 = 2 · Precision · Recall / (Precision + Recall)

    The dice metric is common in medical imaging and is often used to judge the quality of a segmentation algorithm. Intuitively, it is the ratio of the intersection of two regions (prediction and ground truth) to their total area; a perfect segmentation gives a value of 1. [Figure: overlap between prediction and ground truth]
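    A standalone sketch of the dice computation for binary masks (it mirrors the calculation inside check_accuracy in utils.py):

    import torch

    def dice_score(pred, target, eps=1e-8):
        # pred and target are binary {0, 1} tensors of the same shape
        intersection = (pred * target).sum()
        return (2 * intersection) / ((pred + target).sum() + eps)

    pred = (torch.rand(1, 1, 160, 240) > 0.5).float()
    print(dice_score(pred, pred))  # perfect overlap -> 1.0 (up to eps)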

     

    In this experiment, accuracy reached 60%+ while the dice score was only 0.4+, so the overall result was not good.

     

    5. Subsequent Development of UNet

    (1) The core ideas of the UNet network:

    • Downsampling + upsampling as the overall network structure (Encoder-Decoder)
    • Multi-scale feature fusion
    • A pattern for how information flows through the network
    • A pixel-level segmentation map as the output

    (2) For thoughts on improving UNet, see: Talking About UNet Image Segmentation — 3D视觉工坊's blog, CSDN

    Many people like to build on UNet: swap in a stronger encoder, then implement a matching decoder by hand. As for why UNet is the network people choose to improve on, it is probably because its structure is simple and its results are at least passable in many scenarios.

    Compared with later variants in the series, a weakness of UNet's original design lies in information fusion and positional alignment (keeping positions from shifting).

    Original article: https://blog.csdn.net/benben044/article/details/126142576