目录
项目完整下载地址:UNet 网络对图像的分割
之前已经将unet的网络模块、dataset数据加载和train训练数据已经解决了,这次要将unet网络去分割图像,下面是之前的链接
unet 网络:UNet - unet网络
dataset 数据处理:UNet - 数据加载 Dataset
train 网络训练:UNet - 训练数据train
待分割的图像如下:
存放的路径在U-net项目的predict里面
我们的目标是将predict里面所有的图片分割出来,按照名称顺序保存在result文件夹里面:
首先定义图片的预处理,按照dataset里面相同的方式进行预处理
然后是加载网络的模型和网络参数
然后加载predict里面所有待处理图片的路径
需要注意的是,os.listdir 加载的只是里面每个图片,并不是图片的具体路径。tests_path 里面的内容如下面的注释所示:
接下来就可以分割图片了
因为tests_path 里面每个文件是 x.png 即文件名+后缀的方式。通过split的 '.' 分割成x和后缀名png的形式,[-2]代表取倒数第二个值,就可以将每个文件名x取出来,然后将路径拼接就可以存放到result里面
open图像的时候,也要注意,test_path 只是遍历tests_path 里面的文件,需要加上之前的predict路径才能正确的读取到每个待分割的图片
因为这里处理图像会改变size成480*480的形式,想要将输出的结果保持不变的话,在网络预测前将图像的大小保存下来就可以了。(注:这里的size和opencv里面的shape返回值是反过来的)
这里不清楚的可以通过调试,打印每个变量的内容看一下就可以了
接下来就是网络预测的部分,这里输出的size是(batch,channel,height,width),因为这里的batch是1,channel 灰度图片因此也是1,这里通过squeeze将1的维度删去,只需要图像的大小
下面是squeeze的用法
然后图像保存的话,要转到cpu上面 ,这一步不知道为啥,但是不加这一步会报错
最后就是保存图像了,将网络的结果二值化后,还原图像再保存就可以了
predict里面待预测的图片
result 里面分割好的图片
下面是 参考文章 博主的分割结果
对比发现,有些小的细节会丢失,但是大概的轮廓分割出来了
完整的项目可以在 这里 下载
- import numpy as np
- import torch
- import cv2
- from model import UNet
- from torchvision import transforms
- from PIL import Image
- import os
-
-
- # 预处理
- transform = transforms.Compose([
- transforms.Resize((480,480)), # 缩放图像
- transforms.ToTensor(),
- ])
-
- # 加载模型
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- net = UNet(in_channels=1, num_classes=1)
- net.load_state_dict(torch.load('Unet.pth', map_location=device))
- net.to(device)
-
- # 测试模式
- net.eval()
- # 读取所有图片路径
- tests_path = os.listdir('./predict/') # 获取 './predict/' 路径下所有文件,这里的路径只是里面文件的路径
- ''''
- print(tests_path)
- ['0.png', '1.png', '10.png', '11.png', '12.png', '13.png', '14.png',
- '15.png', '16.png', '17.png', '18.png', '19.png', '2.png', '20.png',
- '21.png', '22.png', '23.png', '24.png', '25.png', '26.png', '27.png',
- '28.png', '29.png', '3.png', '4.png', '5.png', '6.png', '7.png', '8.png', '9.png']
- '''
-
-
- with torch.no_grad(): # 预测的时候不需要计算梯度
- for test_path in tests_path: # 遍历每个predict的文件
- save_pre_path = './result/'+test_path.split('.')[-2] + '_res.png' # 将保存的路径按照原图像的后缀,按照数字排序保存
- img = Image.open('./predict/' +test_path) # 预测图片的路径
- width,height = img.size[0],img.size[1] # 保存图像的大小
- img = transform(img)
- img = torch.unsqueeze(img,dim = 0) # 扩展图像的维度
-
- pred = net(img.to(device)) # 网络预测
- pred = torch.squeeze(pred) # 将(batch、channel)维度去掉
- pred = np.array(pred.data.cpu()) # 保存图片需要转为cpu处理
-
- pred[pred >= 0] = 255 # 处理结果二值化
- pred[pred < 0] = 0
-
- pred = np.uint8(pred) # 转为图片的形式
- pred = cv2.resize(pred,(width,height),cv2.INTER_CUBIC) # 还原图像的size
- cv2.imwrite(save_pre_path, pred) # 保存图片
-