• UNet - unet网络


    目录

    1. u-net介绍

    2. u-net网络结构

    3. u-net 网络搭建

    3.1 DoubleConv

    3.2 Down 下采样

    3.3 Up 上采样

    3.4 网络输出

    3.5 UNet 网络

    UNet 网络

    forward  前向传播

    3.6 网络的参数

    4. 完整代码


    1. u-net介绍

    Unet网络是医学图像分割领域常用的分割网络,因为网络的结构很像个U,所以称为Unet

    Unet 网络是针对像素点的分类,之前介绍的LeNet、ResNet等等都是图像分类,最后分的是整幅图像的类别,而Unet是对像素点输出的是前景还是背景的分类

    注:因为Unet 具体的网络框架均有所不同,例如有的连续卷积后会改变图像的size,有的上采样用的是线性插值的方法。这只介绍same卷积和上采样用的转置卷积

    Unet网络是个U型结构,左边是Encoder,右边为Decoder

    左边是下采样的过程,通过减少图像size,增加图像channel来提取特征。

    右边是还原图像的过程,上采样将逐步还原图像的size,这里上采样的输入特征图不仅仅是上一步的输出,还包含了左边对应特征信息。

    2. u-net网络结构

    本章采用的unet网络如图,为了后面数据的训练和预测。这里实现的方式和下图有些细小的区别,具体的会在下面讲解

    首先,网络输入图像的size设定为(480,480)的灰度图像(注意:这里输入是单通道的灰度图)

    然后经过成对的3*3卷积,将图像的深度加深,变成维度为(64,480,480),这里因为图像的size没有变,又因为kernel_size = 3,stride = 1,因此需要保证padding = 1

    接下来是下采样层,先经过一个最大池化层,stride = 2,kernel_size = 2 将图像的size变为原来的一半。然后接两个3*3 的卷积,输出的特征图维度是(128,240,240)

    下采样层总共有四次,根据每次下采样都会将图像的size减半,图像的channel翻倍来计算的话。最后一次图像的size = 480 / (2^4) = 30 ,channel = 64 * (2^4) = 1024 ,所以最后一次下采样图像的维度为(1024,30,30)------> 这里和图上不一样,因为后面用的是转置卷积

    左边的下采样部分实现后,就是右边的上采样部分

    上采样会使图像的channel减半,size变为两倍,正好和下采样的部分反过来。这里利用的操作是转置卷积,转置卷积具体的实现这里不做介绍,主要看它的维度变换。转置卷积变换的公式为:

    out = (in - 1) * stride - 2 * padding + ksize

    这里为了保证图像的size变为两倍,所以要保证 out = 2 * in ,而in的系数2只能从stride来,所以公式变为out = 2 * in - 2 - 2 * padding + ksize ,这里我们让ksize = 2,因此padding = 0 就可以满足要求。而channel的减半只需要把卷积核的个数减半即可

    之前介绍过,最后一层的维度是(1024,30,30),这样通过转置卷积的操作图像的维度就变成了(512,60,60),刚好等于左边下采样的维度!! 所以将它们加在一块,然后进行成对的3*3卷积

    之后就是和下采样的次数一样,重复四次上采样,直到将图像还原成(64,480,480)

    最后一步,如果是图像分类的话,这里应该是全连接层找最大的预测值了。但是Unet是像素点的分类,所以最后产生的也是一副图像,因为这时候图像的size已经是480不需要变了,只需要将图像的channel改变,所以这里只需要一个kernel_size = 1的卷积核就可以了。

    注:最后输出图像的维度是(480,480)的灰度图像,准确的说是二值图像

    3. u-net 网络搭建

    3.1 DoubleConv

    观察unet 网络可以发现,3*3的卷积核都是成对出现的,所以这里将成对卷积核的操作封装成一个类

    1. 因为采用的是两个连续的3*3  卷积,不改变图像的size,所以这里卷积的参数要设置padding=1

    2. ResNet 介绍过,BN代替Dropout 的时候,不需要Bias 

    3. 最后经过ReLU 激活函数

    3.2 Down 下采样

    然后定义下采样的操作

     

    1. 这里下采样采用的就是最大池化层,kernel_size = 2,padding =2 会让图像的size减半

    2. 然后经过两个连续3*3 的卷积

    3. 将 下采样+两个3*3 的卷积 封装成一个新的类Down

    3.3 Up 上采样

    然后是定义上采样

     

    1. 上采样用的是转置卷积,会将图像的size扩大两倍

    2.  注意这里不是定义成 Sequential ,因为 Sequential 会从上到下顺序传播。这里还需要一步尺度融合,就是拼接的操作

    3. 前向传播的时候,图像首先上采样,会将channel减小一半,size扩大两倍。这样就和左边对应的下采样的位置维度一致,将它们通过torch.cat 拼接,dim = 1是因为batch的维度是0 。然后经过两个3*3 的卷积就行了

    3.4 网络输出

    最后网络的输出很简单,经过一个1*1 的卷积核,不改变size的情况下。通过卷积核的个数调整图像的channel就行了

    3.5 UNet 网络

    UNet 网络

    网络的框架很简单,因为每个小的模块已经搭好了,将它们拼接起来就行了

    因为搭建小的模块的时候,我们对于模块的输入都是in和out channel,所以在定义网络的时候,每个模块只要传入对应的channel就行了。

    这里按照UNet 网络的框架设置

     

    forward  前向传播

    前向传播的过程如下:

    在下采样的时候,每个输出都要用变量保存,为了和后面上采样拼接使用

     

    3.6 网络的参数

    1. # 计算 UNet 的网络参数个数
    2. model = UNet(in_channels=1,num_classes=1)
    3. print("Total number of paramerters in networks is {} ".format(sum(x.numel() for x in model.parameters())))

    UNet 网络参数个数为:

     

    4. 完整代码

    代码:

    1. import torch.nn as nn
    2. import torch
    3. # 搭建unet 网络
    4. class DoubleConv(nn.Module): # 连续两次卷积
    5. def __init__(self,in_channels,out_channels):
    6. super(DoubleConv,self).__init__()
    7. self.double_conv = nn.Sequential(
    8. nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1,bias=False), # 3*3 卷积核
    9. nn.BatchNorm2d(out_channels), # 用 BN 代替 Dropout
    10. nn.ReLU(inplace=True), # ReLU 激活函数
    11. nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1,bias=False),
    12. nn.BatchNorm2d(out_channels),
    13. nn.ReLU(inplace=True)
    14. )
    15. def forward(self,x): # 前向传播
    16. x = self.double_conv(x)
    17. return x
    18. class Down(nn.Module): # 下采样
    19. def __init__(self,in_channels,out_channels):
    20. super(Down, self).__init__()
    21. self.downsampling = nn.Sequential(
    22. nn.MaxPool2d(kernel_size=2,stride=2),
    23. DoubleConv(in_channels,out_channels)
    24. )
    25. def forward(self,x):
    26. x = self.downsampling(x)
    27. return x
    28. class Up(nn.Module): # 上采样
    29. def __init__(self, in_channels, out_channels):
    30. super(Up,self).__init__()
    31. self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) # 转置卷积
    32. self.conv = DoubleConv(in_channels, out_channels)
    33. def forward(self, x1, x2):
    34. x1 = self.upsampling(x1)
    35. x = torch.cat([x2, x1], dim=1) # 从channel 通道拼接
    36. x = self.conv(x)
    37. return x
    38. class OutConv(nn.Module): # 最后一个网络的输出
    39. def __init__(self, in_channels, num_classes):
    40. super(OutConv, self).__init__()
    41. self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)
    42. def forward(self, x):
    43. return self.conv(x)
    44. class UNet(nn.Module): # unet 网络
    45. def __init__(self, in_channels = 1, num_classes = 1):
    46. super(UNet, self).__init__()
    47. self.in_channels = in_channels # 输入图像的channel
    48. self.num_classes = num_classes # 网络最后的输出
    49. self.in_conv = DoubleConv(in_channels, 64) # 第一层
    50. self.down1 = Down(64, 128) # 下采样过程
    51. self.down2 = Down(128, 256)
    52. self.down3 = Down(256, 512)
    53. self.down4 = Down(512, 1024)
    54. self.up1 = Up(1024, 512) # 上采样过程
    55. self.up2 = Up(512, 256)
    56. self.up3 = Up(256, 128)
    57. self.up4 = Up(128, 64)
    58. self.out_conv = OutConv(64, num_classes) # 网络输出
    59. def forward(self, x): # 前向传播 输入size为 (10,1,480,480),这里设置batch = 10
    60. x1 = self.in_conv(x) # torch.Size([10, 64, 480, 480])
    61. x2 = self.down1(x1) # torch.Size([10, 128, 240, 240])
    62. x3 = self.down2(x2) # torch.Size([10, 256, 120, 120])
    63. x4 = self.down3(x3) # torch.Size([10, 512, 60, 60])
    64. x5 = self.down4(x4) # torch.Size([10, 1024, 30, 30])
    65. x = self.up1(x5, x4) # torch.Size([10, 512, 60, 60])
    66. x = self.up2(x, x3) # torch.Size([10, 256, 120, 120])
    67. x = self.up3(x, x2) # torch.Size([10, 128, 240, 240])
    68. x = self.up4(x, x1) # torch.Size([10, 64, 480, 480])
    69. x = self.out_conv(x) # torch.Size([10, 1, 480, 480])
    70. return x
    71. # 计算 UNet 的网络参数个数
    72. model = UNet(in_channels=1,num_classes=1)
    73. print("Total number of paramerters in networks is {} ".format(sum(x.numel() for x in model.parameters())))

  • 相关阅读:
    趣聊粒子滤波器Particle Filter的概念问题
    Item 39: Consider void futures for one-shot event communication.
    AC自动机小结
    分布式BASE理论
    基于BP神经网络、kmeans聚类和HC模型的火焰特征数据识别算法matlab仿真
    Django前后端分离之后端基础3
    【Jetson】使用 Jetson 控制无人车常用指令
    电子学:第012课——实验 11:光和声
    乘法器设计(流水线)verilog code
    遥感云大数据在灾害、水体与湿地领域经典案例及GPT模型应用教程
  • 原文地址:https://blog.csdn.net/qq_44886601/article/details/127855473