• PyTorch 实现 U-Net


    http://t.zoukankan.com/wanghui-garcia-p-10719121.html

    https://github.com/1024210879/unet-denoising-dirty-documents/blob/master/datasets.py

    （原文此处为 U-Net 网络结构示意图，图片在转载时已丢失。）

    unet.py（原文标题为 Model.py；train.py 中通过 `from unet import UNet` 导入，因此需保存为 unet.py）

    # sub-parts of the U-Net model
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    # 实现左边的横向卷积
    class double_conv(nn.Module):
        """Two stacked (3x3 conv -> BatchNorm -> ReLU) units.

        Both convolutions use kernel_size=3, padding=1, stride=1, so the spatial
        size is preserved: H_out = (H + 2*1 - 1*(3-1) - 1) / 1 + 1 = H. (The
        original U-Net paper uses unpadded convs that shrink 572 -> 570 -> 568;
        this implementation does not.)

        Args:
            in_ch: channels of the input tensor (fed to the first conv only).
            out_ch: channels produced by both convs and normalized by both BNs.
        """

        def __init__(self, in_ch, out_ch):
            super().__init__()
            stack = []
            # First unit maps in_ch -> out_ch, second unit maps out_ch -> out_ch.
            for source_ch in (in_ch, out_ch):
                stack += [
                    nn.Conv2d(source_ch, out_ch, 3, padding=1),
                    # In training mode BN normalizes with batch statistics and
                    # updates its running mean/variance.
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(inplace=True),
                ]
            # Kept as a single Sequential named `conv` so state_dict keys are
            # conv.0 / conv.1 / conv.3 / conv.4.
            self.conv = nn.Sequential(*stack)

        def forward(self, x):
            return self.conv(x)
    
    # 实现左边第一行的卷积
    class inconv(nn.Module):
        """Input stem: one double_conv mapping the image's in_ch channels to out_ch.

        In UNet this is built as inconv(in_channels, 64).
        """

        def __init__(self, in_ch, out_ch):
            super().__init__()
            self.conv = double_conv(in_ch, out_ch)

        def forward(self, x):
            return self.conv(x)
    
    # 实现左边的向下池化操作,并完成另一层的卷积
    class down(nn.Module):
        """Encoder stage: 2x2 max-pool followed by double_conv(in_ch, out_ch).

        MaxPool2d(2) halves H and W (flooring odd sizes); the double_conv then
        keeps that reduced size and changes the channel count to out_ch.
        """

        def __init__(self, in_ch, out_ch):
            super().__init__()
            pool = nn.MaxPool2d(2)
            convs = double_conv(in_ch, out_ch)
            # Stored as one Sequential named `mpconv` (state_dict prefix mpconv.1.*).
            self.mpconv = nn.Sequential(pool, convs)

        def forward(self, x):
            return self.mpconv(x)
    
    # 实现右边的向上的采样操作,并完成该层相应的卷积操作
    class up(nn.Module):
        """Decoder stage: upsample, align with the skip tensor, concatenate, double_conv.

        Args:
            in_ch: channel count AFTER concatenating the upsampled tensor with the
                skip tensor (i.e. the input width of the trailing double_conv).
            out_ch: channels produced by the trailing double_conv.
            bilinear: if True, upsample x2 with parameter-free bilinear
                interpolation (align_corners=True); output size is
                floor(H * 2). Otherwise use a learned 2x2, stride-2 transposed
                convolution on in_ch // 2 channels, whose output size is
                (H - 1) * stride - 2 * padding + dilation * (kernel_size - 1)
                + output_padding + 1 = 2H.
        """

        def __init__(self, in_ch, out_ch, bilinear=True):
            super().__init__()
            # Learned upsampling would be nicer, but the original author lacked the
            # memory for the extra weights, so bilinear is the default.
            self.up = (
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
                if bilinear
                else nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
            )
            self.conv = double_conv(in_ch, out_ch)

        def forward(self, x1, x2):
            """Upsample x1, fit x2 (the encoder skip) to it, and fuse them.

            Tensors are NCHW: dim 2 is height, dim 3 is width.
            """
            upsampled = self.up(x1)

            # Size gap between the upsampled tensor and the skip tensor.
            gap_h = upsampled.size(2) - x2.size(2)
            gap_w = upsampled.size(3) - x2.size(3)

            # F.pad takes (left, right, top, bottom). Positive amounts zero-pad,
            # negative amounts crop, so the skip tensor ends up exactly as large
            # as the upsampled one (e.g. a gap of -8 crops 4 from each side).
            # For padding issues, see
            # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
            # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
            skip = F.pad(x2, (gap_w // 2, gap_w - gap_w // 2,
                              gap_h // 2, gap_h - gap_h // 2))

            # Concatenate along channels (skip first), e.g. 512 + 512 -> 1024 in up1.
            fused = torch.cat([skip, upsampled], dim=1)
            return self.conv(fused)
    
    # 实现右边的最高层的最右边的卷积
    class outconv(nn.Module):
        """Output head: a 1x1 convolution (with bias) projecting in_ch to out_ch channels."""

        def __init__(self, in_ch, out_ch):
            super().__init__()
            self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

        def forward(self, x):
            return self.conv(x)
    class UNet(nn.Module):
        """U-Net with four encoder (down) and four decoder (up) stages.

        Args:
            in_channels: channels of the input image (e.g. 1 for grayscale,
                3 for colour).
            out_channels: channels of the output map.

        Channel flow: in_channels -> 64 -> 128 -> 256 -> 512 -> 512 (bottleneck).
        Each decoder stage concatenates its upsampled input with the matching
        encoder output, so its input width is doubled:
        512+512=1024 -> 256, 256+256=512 -> 128, 128+128=256 -> 64,
        64+64=128 -> 64, then a 1x1 conv maps 64 -> out_channels.

        For H and W divisible by 16 the output has the input's spatial size;
        otherwise skip tensors are cropped/padded to the upsampled size and the
        output can be smaller. The raw output is returned with no activation.
        """

        def __init__(self, in_channels, out_channels):
            super().__init__()
            # Encoder. Attribute names/order form the saved state_dict keys.
            self.inc = inconv(in_channels, 64)
            self.down1 = down(64, 128)
            self.down2 = down(128, 256)
            self.down3 = down(256, 512)
            self.down4 = down(512, 512)
            # Decoder: first argument is the post-concatenation channel count.
            self.up1 = up(1024, 256)
            self.up2 = up(512, 128)
            self.up3 = up(256, 64)
            self.up4 = up(128, 64)
            self.outc = outconv(64, out_channels)

        def forward(self, x):
            # Encoder: keep every stage output except the bottleneck as a skip.
            skips = [self.inc(x)]
            for stage in (self.down1, self.down2, self.down3):
                skips.append(stage(skips[-1]))
            y = self.down4(skips[-1])

            # Decoder: consume the skips deepest-first.
            decoders = (self.up1, self.up2, self.up3, self.up4)
            for stage, skip in zip(decoders, reversed(skips)):
                y = stage(y, skip)
            # For binary segmentation a caller could apply torch.sigmoid here.
            return self.outc(y)
    
    

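    datasets.py（原文标题为 dataset.py；train.py 中通过 `from datasets import UNetDataset` 导入，因此需保存为 datasets.py。其中引用的 transforms 模块原文未给出）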

    import torch
    import os
    import numpy as np
    import transforms as Transforms
    from torch.utils.data import Dataset
    
    
    class UNetDataset(Dataset):
        """Paired dataset of raw int16 images read from two directories.

        Every file in ``dir_train`` and ``dir_mask`` is read as headerless
        ``int16`` data and reshaped to ``shape``. Files are paired by their
        position in the name-sorted directory listings, so both directories are
        expected to hold the same file names. The returned label is the
        difference ``mask - input`` (i.e. the residual the network must predict).

        Args:
            dir_train: directory of input files.
            dir_mask: directory of target files.
            transform: optional sequence of callables ``f(image, label)`` returning
                a new ``(image, label)`` pair, applied in order.
            shape: 2-D shape of every raw file; defaults to (512, 512).
        """

        def __init__(self, dir_train, dir_mask, transform=None, shape=(512, 512)):
            self.dirTrain = dir_train
            self.dirMask = dir_mask
            self.transform = transform
            self.shape = tuple(shape)
            # os.listdir order is arbitrary; sort both listings so index i refers
            # to the same file name in the input and target directories.
            self.dataTrain = [os.path.join(self.dirTrain, filename)
                              for filename in sorted(os.listdir(self.dirTrain))]
            self.dataMask = [os.path.join(self.dirMask, filename)
                             for filename in sorted(os.listdir(self.dirMask))]
            self.trainDataSize = len(self.dataTrain)
            self.maskDataSize = len(self.dataMask)

        def __getitem__(self, index):
            """Return ``(image, label)`` as arrays with a leading channel axis.

            ``image`` is int16 of shape ``(1, *shape)``; ``label`` is int32
            ``mask - image`` of the same shape (before any transform).
            """
            assert self.trainDataSize == self.maskDataSize
            image = np.fromfile(self.dataTrain[index], dtype='int16').reshape(self.shape)
            label = np.fromfile(self.dataMask[index], dtype='int16').reshape(self.shape)
            # Widen before subtracting: the difference of two int16 values can
            # exceed the int16 range and would otherwise silently wrap around.
            label = label.astype(np.int32) - image.astype(np.int32)

            if self.transform:
                for method in self.transform:
                    image, label = method(image, label)

            return image[np.newaxis], label[np.newaxis]

        def __len__(self):
            assert self.trainDataSize == self.maskDataSize
            return self.trainDataSize
    

    train.py

    损失函数采用 L1 Loss（nn.L1Loss，即平均绝对误差）。

    import torch
    import torch.nn as nn
    from torch import optim
    import os
    from unet import UNet
    from datasets import UNetDataset
    import transforms as Transforms
    from torch.utils.data import DataLoader
    
    # Checkpoints are written under ./weight; create it up front. exist_ok avoids
    # the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs('./weight', exist_ok=True)
    LR = 1e-3          # Adam learning rate
    EPOCH = 250        # number of passes over the dataset
    BATCH_SIZE = 4
    weight = './weight/weight.pth'                                # network weights only
    weight_with_optimizer = './weight/weight_with_optimizer.pth'  # network + optimizer, used to resume
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    
    def train():
        """Train UNet(1, 1) with an L1 loss, resuming from and saving to ./weight.

        Reads (input, target) pairs via UNetDataset, runs EPOCH epochs of Adam,
        prints the loss every 50 batches and the mean loss per epoch, and after
        every epoch writes two checkpoints: network + optimizer state (used to
        resume on the next run) and network weights only.

        Raises:
            ValueError: if the dataset yields no batches.
        """
        # Augmentation options. NOTE(review): this list is currently unused —
        # the dataset below is built with transform=None.
        transforms = [
            # Transforms.ToGray(),
            # Transforms.RondomFlip(),
            # Transforms.RandomRotate(15),
            Transforms.RandomCrop(128,128),
            # Transforms.Log(0.5),
            # Transforms.EqualizeHist(0.5),
            # Transforms.Blur(0.2),
            # Transforms.ToTensor()
        ]
        dataset = UNetDataset(r'D:\DataSet\artifact\artifact_part\input', r'D:\DataSet\artifact\artifact_part\target', transform=None)
        dataLoader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
        # Hoisted out of the loop: previously it was only bound inside the batch
        # loop, so an empty loader crashed with an unbound name at the mean.
        nstep = len(dataLoader)
        if nstep == 0:
            raise ValueError('training dataset is empty')

        # Model, optimizer, and L1 (mean absolute error) loss.
        net = UNet(1, 1).to(device)
        optimizer = optim.Adam(net.parameters(), lr=LR)
        loss_func = nn.L1Loss(reduction='mean')

        # Resume if a checkpoint exists. map_location lets a checkpoint saved on
        # GPU load on a CPU-only machine (and vice versa).
        if os.path.exists(weight_with_optimizer):
            checkpoint = torch.load(weight_with_optimizer, map_location=device)
            net.load_state_dict(checkpoint['net'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('load weight')

        save_every = 1  # epochs between checkpoint writes
        for epoch in range(EPOCH):
            total_loss = 0.0
            for step, (batch_x, batch_y) in enumerate(dataLoader):
                batch_x = batch_x.to(device).float()
                batch_y = batch_y.to(device).float()
                output = net(batch_x)   # torch.float32
                loss = loss_func(output, batch_y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # .item() yields a plain float; summing the loss tensor itself
                # would chain every step's autograd history across the epoch.
                batch_loss = loss.item()
                total_loss += batch_loss
                if step % 50 == 0:
                    print("epoch: [%3d/%d] Batch:[%5d/%5d] | loss: %.4f"
                          % (epoch, EPOCH, step, nstep, batch_loss))

            mean_loss = total_loss / nstep
            print('epoch: %d | loss: %.4f' % (epoch, mean_loss))

            if (epoch + 1) % save_every == 0:
                torch.save({
                    'net': net.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, weight_with_optimizer)
                torch.save({
                    'net': net.state_dict()
                }, weight)
                print('saved')


    if __name__ == '__main__':
        train()
    
  • 相关阅读:
    海外IP代理科普——API代理是什么?怎么用?
    进程初识
    vi/vim 删除:一行, 一个字符, 单词, 每行第一个字符 命令
    【学习笔记】元学习如何解决计算机视觉少样本学习的问题?
    为什么免费证书的有效期为90天
    STC8H开发(十六): GPIO驱动XL2400无线模块
    今 年 测 试 行 业 企 业 招 聘 真 相...
    终于解决VScode中python/C++打印中文全是乱码的问题了
    处理耗时任务
    链表的边界
  • 原文地址:https://blog.csdn.net/weixin_44708254/article/details/126725040