• U2Net使用方法和实现多类别语义分割模型改造


    作者的碎碎念:U2Net是用来实现SOD的语义分割,本篇论文会介绍算法内容、主要代码、使用方法,以及如何将二分类语义分割修改为多类别语义模型。如果只想知道怎么训练自己的数据集,或者如何修改网络,可以通过目录进行跳转。
    欢迎点赞、评论或收藏❤️


    (一)相关链接

    1. 论文名称
      《U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection》
    2. github链接
      https://github.com/xuebinqin/U-2-Net
    3. paper
      https://arxiv.org/pdf/2005.09007.pdf

    (二)算法内容

    1. 摘要

      U²-Net是显著物体检测(salient object detection,简写SOD)的一个网络,并且现在已经是Python的抠图工具Rembg的基础算法

    • 什么是SOD?
        SOD是模拟人类视觉感知系统来定位场景中最吸引人的目标,例如人像
    • 算法优点总结
      (1)能获取到更多的上下文信息(RSU块,ReSidual U-blocks)
      (2)增加网络深度但没有增加计算量。并且可以从0开始训练,不用从分类预训练网络中再训练
    • 模型大小
        U2-Net (176.3 MB, 30 FPS on GTX 1080Ti GPU)
        U2-Net†(4.7 MB, 40 FPS)

    2. 介绍

    • 现有的SOD网络存在什么问题?
      (1)现有的模式基本都是使用已有的backbone,例如AlexNet、VGG、ResNet。这些基础的网络都是为分类任务而设计的,提取的特征更多是语义特征,而不是定位特征和全局对比的信息。
      (2)耗用大量的资源
      (3)牺牲高分辨率的特征映射来实现更深层次的体系结构
    • U2Net的目标是网络更深、使用的资源和计算量更少、能够保持高分辨率的特征图。怎么做呢?
      (1)用两级的内嵌U型结构,不使用分类的backbone
      (2)新型的网络结构更深、能获取高分辨率图像、不增加内存和计算量

    3. 网络架构

    • 卷积结构和RSU结构比对
      在这里插入图片描述

    (1)( a ) Plain convolution blockPLN
         ( b ) Residual-like block RES
         ( c ) Dense-like block DSE
         ( d ) Inception-like block INC
         ( e ) Our residual U-blockRSU
    (2)(a)到( c )是典型的卷积结构,用了1x1和3x3的卷积,感受野太小,只能用来获取local feature
    (3)(d)用了空洞卷积增大了感受野,但是需要大的内存和计算资源
    (4)RSU-L模块,(L代表层数),Cin:输入通道,Cout:输出通道,M:RSU内部通道

    • 开销比对
      在这里插入图片描述
      RSU的开销(overhead)不大,因为都是下采样,DSE和INC比较大
    • 残差结构比对
      在这里插入图片描述
      (1)残差块:H(x) = F2(F1(x))+x,H(x)是x的映射,F1和F2是卷积操作【对应两个weight layer】
      (2)RSU:HRSU (x) = U(F1(x))+F1(x),RSU和残差不同的地方,是将卷积替换成像Unet的U型结构U-block,原来的输入x替换成F1(x)【weight layer之后】
    • 网络架构
      在这里插入图片描述

      U-Net-like这种结构本来就有,只不过是级联起来,Uxn Net,而作者提出来的是 Un Net,用内嵌(nested)结构而不是级联结构
    (1)结构特点:11个stage,每个stage都是RSU结构
       🔸 a six stages encoder
       🔸a five stages decoder
       🔸a saliency map fusion module attached with the decoder stages and the last encoder stage
    (2)编码器:
       🔹En_1、En_2、En_3、En_4(即前四个)用到的RSU层数是 RSU-7、 RSU-6、 RSU-5、 RSU-4,层数越多,尺度信息越丰富
       🔹En-5和En-6用了RSU-4F,用了空洞卷积,保证了输入输出是相同的分辨率
    (3)解码器:
       De-5也是用了RSU-4F,和En-5、En-6类似
    (4)融合模块(saliency map fusion module):
       编码器和解码器的输出,经过3x3卷积和sigmoid,upsample,输出了6个概率热力图:S_side(6)、S_side(5)、S_side(4)、S_side(3)、S_side(2)、S_side(1) ,用1x1卷积进行融合,产生了S_fuse

    4. loss函数

    在这里插入图片描述
    ✅总Loss等于所有loss之和,包括S_side(6)、S_side(5)、S_side(4)、S_side(3)、S_side(2)、S_side(1),和融合的S_fuse
    在这里插入图片描述
    ✅每一层的S_side(x)的loss,使用了二分类交叉熵损失函数

    5. 作者实验结果

    在这里插入图片描述
    Red, Green, and Blue indicate the best, second best and third best performance
    在这里插入图片描述

    (三)如何训练自己的数据

    1. 标注

    用labelme标注图片,生成json文件
    在这里插入图片描述

    2. mask图像

    将json文件转换为mask图片,背景黑色,物体白色,下面是转换代码:

    import cv2
    import json
    import numpy as np
    import os
    import sys
    
    
    def func(file:str) -> np.ndarray:
        with open(file, mode='r', encoding="utf-8") as f:
            configs = json.load(f)
        shapes = configs["shapes"]
    
        png = np.zeros((configs["imageHeight"], configs["imageWidth"], 3), np.uint8)
    
        for shape in shapes:
            cv2.fillPoly(png, [np.array(shape["points"], np.int32)], (255,255,255))
    
        return png
    
    
    if  __name__ == "__main__":
    
        if len(sys.argv) != 3:
            raise ValueError("json文件或目录 输出路径")
    
        if os.path.isdir(sys.argv[1]):
            for file in os.listdir(sys.argv[1]):
                cv2.imwrite(os.path.join(sys.argv[2], os.path.splitext(file)[0]+".png" ), func(os.path.join(sys.argv[1], file)))
        else:
            cv2.imwrite(os.path.join(sys.argv[2], os.path.splitext(os.path.basename(sys.argv[1]))[0]+".png"), func(sys.argv[1]))
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31

    在这里插入图片描述

    转换的mask图像

    3. 训练数据集格式

    1️⃣在工程目录创建目录:train_data/DUTS/DUTS-TR/DUTS-TR/
    2️⃣在第一步骤创建的目录上,创建目录im_aug,将原图放在这
    3️⃣在第一步骤创建的目录上,创建目录gt_aug,将转换的mask图放在这

    4. 配置文件修改

      打开u2net_train.py,一般可以设置这几项:
      model_name = ‘u2net’ # 用u2net或者u2netp模型进行训练
      epoch_num = 100000 # 训练轮次
      batch_size_train = 12 # batchsize
      save_frq = 2000 # 每2000个iter保存一个模型

    5. 训练命令

    python u2net_train.py
    
    • 1

    6. 测试命令

    python u2net_test.py
    
    • 1

    (四)多类别语义分割

      作者提供的代码只实现了二分类的语义分割,U2Net是否可以用来做多类别的语义分割?答案是可以了,下面提供了将二分类语义分割转换为多类别语义分割的方法

    2023.11.14新增完整代码在资源,0积分即可下载,有需要的自取
    https://download.csdn.net/download/jin__9981/88533519

    1. 实现思路

    🔺项目背景:图片有两个类别,分别是螺丝钉和位移线
    🔺类别:两个类别+背景,num_class = 3,如果有更多类别,则是n+1类,1是背景
    🔺mask图片:二分类时,填充的是0和255;多分类,不同类别可以填充为0(背景)、1(螺丝钉)、2(位移线),所以最多只能分出0~255个类别。查看3个类别的mask,因为像素值只有0、1、2,肉眼看基本是一张黑色图像
    🔺模型输出:三个类别,输出三个通道,如[3, 320, 320],每一个通道代表一个类别

    2. 修改方法

    (1)获取多类别训练mask脚本

    import cv2
    import json
    import numpy as np
    import os
    import sys
    
    
    def func(file):
        with open(file, mode='r', encoding="utf-8") as f:
            configs = json.load(f)
        shapes = configs["shapes"]
    
        png = np.zeros((configs["imageHeight"], configs["imageWidth"], 3), np.uint8)
    
        for shape in shapes:
            label = shape['label']
            if label == 'lm':
                cv2.fillPoly(png, [np.array(shape["points"], np.int32)], (1,1,1))
            else:
                cv2.fillPoly(png, [np.array(shape["points"], np.int32)], (2,2,2))
    
        return png
    
    
    if  __name__ == "__main__":
        json_dir = "./train_data/labels_json"
        
        save_dir = './train_data/masks'
    
    
        for file in os.listdir(json_dir):
            print(file)
            png = func(os.path.join(json_dir, file))
            print(png.shape)
            save_path = save_dir+'/'+os.path.splitext(file)[0]+".png"
            cv2.imwrite(save_path, png)
            print(save_path)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37

    (2)data_loader.py
       class ToTensor(object)和class ToTensorLab(object)这两个类中,有对label进行归一化操作,去除该操作,因为计算loss的时候,多类别换成交叉熵损失函数,它本身包含了softmax操作
    在这里插入图片描述
    (3)model/u2net.py
       修改模型输出,作者在class U2NETP(nn.Module)和class U2NET(nn.Module)这两个类用了sigmoid函数,需要修改为直接输出,原因同上

    # return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
    return d0, d1, d2, d3, d4, d5, d6
    
    • 1
    • 2

    (4)u2net_train.py
       修改损失函数和模型输出通道,将损失函数由原来的BCELoss,修改为CrossEntropyLoss,并设置模型的输出通道和类别一致

    # bce_loss = nn.BCELoss(size_average=True)  # 注释
    ce_loss = nn.CrossEntropyLoss()  # 添加
    
    • 1
    • 2
    # def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):  # 注释
    #     loss0 = bce_loss(d0, labels_v)
    #     loss1 = bce_loss(d1, labels_v)
    #     loss2 = bce_loss(d2, labels_v)
    #     loss3 = bce_loss(d3, labels_v)
    #     loss4 = bce_loss(d4, labels_v)
    #     loss5 = bce_loss(d5, labels_v)
    #     loss6 = bce_loss(d6, labels_v)
    
    #     loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    #     print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (
    #     loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(), loss4.data.item(), loss5.data.item(),
    #     loss6.data.item()))
    
    #     return loss0, loss
    
    def muti_ce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):  # 添加
        loss0 = ce_loss(d0, labels_v)
        loss1 = ce_loss(d1, labels_v)
        loss2 = ce_loss(d2, labels_v)
        loss3 = ce_loss(d3, labels_v)
        loss4 = ce_loss(d4, labels_v)
        loss5 = ce_loss(d5, labels_v)
        loss6 = ce_loss(d6, labels_v)
    
        loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
        print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (
        loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(), loss4.data.item(), loss5.data.item(),
        loss6.data.item()))
    
        return loss0, loss
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    # ------- 3. define model --------
    # define the net
    n_class = 3
    if (model_name == 'u2net'):
        net = U2NET(3, n_class)
    elif (model_name == 'u2netp'):
        net = U2NETP(3, n_class)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    4. 测试

       该例子中,存在三个类别,分别是背景、螺丝钉、位移线,对应模型三个通道的输出,但模型输出为概率值,如何获取到真实的类别,以及将类别用不同颜色表示出来?可以用下面这个脚本实现模型推理和输出结果图

    import os
    import cv2
    from skimage import io, transform
    import torch
    import torchvision
    from torch.autograd import Variable
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms#, utils
    # import torch.optim as optim
    
    import numpy as np
    from PIL import Image
    import glob
    
    from data_loader import RescaleT
    from data_loader import ToTensor
    from data_loader import ToTensorLab
    from data_loader import SalObjDataset
    
    from model import U2NET # full size version 173.6 MB
    from model import U2NETP # small version u2net 4.7 MB
    
    # normalize the predicted SOD probability map
    def normPRED(d):
        ma = torch.max(d)
        mi = torch.min(d)
    
        dn = (d-mi)/(ma-mi)
    
        return dn
    
    def save_output(image_name,pred,d_dir):
    
        predict = pred
        predict = predict.squeeze()
        predict_np = predict.cpu().data.numpy()
    
        im = Image.fromarray(predict_np*255).convert('RGB')
        img_name = image_name.split(os.sep)[-1]
        image = io.imread(image_name)
        imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
    
        pb_np = np.array(imo)
    
        aaa = img_name.split(".")
        bbb = aaa[0:-1]
        imidx = bbb[0]
        for i in range(1,len(bbb)):
            imidx = imidx + "." + bbb[i]
    
        imo.save(d_dir+imidx+'.png')
    
    def main():
    
        # --------- 1. get image path and name ---------
        model_name='u2net'#u2netp
    
        num_class = 3
    
        image_dir = os.path.join(os.getcwd(), 'test_data', 'ls_test_images')
        prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results_ls' + os.sep)
        model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, 'u2net_bce_itr_1000_train_1.046126_tar_0.124982.pth')
    
        img_name_list = glob.glob(image_dir + os.sep + '*')
        print(img_name_list)
    
        # --------- 2. dataloader ---------
        #1. dataloader
        test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
                                            lbl_name_list = [],
                                            transform=transforms.Compose([RescaleT(320),
                                                                          ToTensorLab(flag=0)])
                                            )
        test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                            batch_size=1,
                                            shuffle=False,
                                            num_workers=1)
    
        # --------- 3. model define ---------
        if(model_name=='u2net'):
            print("...load U2NET---173.6 MB")
            net = U2NET(3,num_class)
        elif(model_name=='u2netp'):
            print("...load U2NEP---4.7 MB")
            net = U2NETP(3,num_class)
    
        if torch.cuda.is_available():
            net.load_state_dict(torch.load(model_dir))
            net.cuda()
        else:
            net.load_state_dict(torch.load(model_dir, map_location='cpu'))
        net.eval()
    
        # --------- 4. inference for each image ---------
        for i_test, data_test in enumerate(test_salobj_dataloader):
    
            print("inferencing:",img_name_list[i_test].split(os.sep)[-1])
    
            inputs_test = data_test['image']
    
            image = cv2.imread(img_name_list[i_test])
            image_name = os.path.basename(img_name_list[i_test])
    
            inputs_test = inputs_test.type(torch.FloatTensor)
    
            if torch.cuda.is_available():
                inputs_test = Variable(inputs_test.cuda())
            else:
                inputs_test = Variable(inputs_test)
    
            d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
            d1 = d1.squeeze(dim=0)    # torch.Size([1, 3, 320, 320]) -> torch.Size([3, 320, 320])
            
            d1 = F.softmax(d1, dim=0)   # [3, 320, 320] 
            # print(d1[0, :, :])
    
            predict_np = torch.argmax(d1, dim=0, keepdim=True)
            # print(predict_np.shape)  # [1, 320, 320],3个类别,对应3个通道,获取概率值最高的下标
    
            predict_np = predict_np.cpu().detach().numpy().squeeze()   # 转到cpu设备
    
            predict_np = cv2.resize(predict_np, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)  # resize和原图一样的大小
            
            r = predict_np.copy()
            b = predict_np.copy()
            g = predict_np.copy()
    
            cls = dict([(1, (0, 0, 255)),
                        (2, (255, 0, 255)),
                        (3, (0, 255, 0)),
                        (4, (255, 0, 0)),
                        (5, (255, 255, 0))])
            for c in cls:
                r[r == c] = cls[c][0]
                g[g == c] = cls[c][1]
                b[b == c] = cls[c][2]
    
            rgb = np.zeros((image.shape[0], image.shape[1], 3))
            # print('类别', np.unique(predict_np))
            rgb[:, :, 0] = r
            rgb[:, :, 1] = g
            rgb[:, :, 2] = b
    
            im = Image.fromarray(rgb.astype(np.uint8))
            im.save('./test_data/my_results_2/' + str(image_name)[:-4] + '.png')
    
            del d1,d2,d3,d4,d5,d6,d7
    
    if __name__ == "__main__":
        main()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153

    5. 训练测试效果

       经过少量数据的训练测试,证明U2Net可以用来做多类别语义分割
    输入图片

    输入测试图片

    在这里插入图片描述

    模型测试效果

    撒花完结🌟🌟🌟

  • 相关阅读:
    Springboot+Mybatis+Mybatisplus 框架中增加自定义分页插件和sql 占位符修改插件
    【Java 进阶篇】使用 Java 和 Jsoup 进行 XML 处理
    upload上传弹窗前二次确认
    Mocha + Chai 测试环境配置,支持 ES6 语法
    C# ZBar解码测试(QRCode、一维码条码)并记录里面隐藏的坑
    408 | 【2017年】计算机统考真题 自用回顾知识点整理
    LightDM简介
    自定义MVC框架02
    centos pip失效
    Linux多任务编程(并发)
  • 原文地址:https://blog.csdn.net/jin__9981/article/details/132712093