前言:
之所以想到用 pytorch ,主要是因为不想在网络模块中调用 opencv 的函数。
调用 opencv 函数的基本步骤如下:先把 pytorch 的 tensor 转到 cpu 上,然后转换成 numpy,再调整到 uint8 格式,然后才能调用 cv2.erode
。 麻烦不说,还无法充分利用 GPU 的并行加速,同时阻断了 gradient 的传播路径,因此有必要用 pytorch 。
形态学是基于形状处理图像的一组广泛的图像处理运算。形态学运算将结构元素应用于输入图像,从而创建相同大小的输出图像。在形态学运算中,输出图像中每个像素的值基于输入图像中对应像素与其相邻像素的比较。
最基本的形态学运算是膨胀和腐蚀。膨胀指将像素添加到图像中对象的边界,而腐蚀指删除对象边界上的像素。对图像中对象添加或删除的像素数量取决于用于处理图像的结构元素的大小和形状。在形态学膨胀和腐蚀运算中,输出图像中任何给定像素的状态通过对输入图像中的对应像素及其相邻像素应用规则来确定。用于处理像素的规则将运算定义为膨胀或腐蚀。下表列出了膨胀和腐蚀的规则。
膨胀和腐蚀的规则
操作 | 规则 | 示例(原始图像和处理后的图像) |
---|---|---|
膨胀 | 输出像素的值是邻域中所有像素的最大值。在二值图像中,如果一个像素的任何相邻像素的值为 形态学膨胀使对象更加明显可见并填充对象中的小孔。线条看起来更粗,填充的形状看起来更大。 |
|
腐蚀 | 输出像素的值是邻域中所有像素的最小值。在二值图像中,如果一个像素的任何相邻像素的值为 形态学腐蚀去除了孤立像素和细线,从而只留下实质对象。剩余线条看起来更细,形状更小。 |
|
下图说明了二值图像的膨胀。结构元素如何定义感兴趣的像素的邻域,该邻域带圆圈。膨胀函数将适当的规则应用于邻域中的像素,并为输出图像中的对应像素赋值。在图中,形态学膨胀函数将输出像素的值设置为 1
,因为由结构元素定义的邻域中的元素之一处于打开状态。有关详细信息,请参阅Structuring Elements。
二值图像的形态学膨胀
下图说明灰度图像的这种处理。膨胀函数将规则应用于感兴趣的圈中像素的邻域。输出图像中对应像素的值被指定为所有邻域像素中的最高值。在图中,输出像素的值是 16
,因为它是由结构元素定义的邻域中的最高值。
灰度图像的形态学膨胀
首先介绍 PyTorch 中一个很有用的函数:unfold
,它的作用是将 tensor 按照固定的 step 和 kernel size 拆分成 patch,每个 patch 为 kernel 覆盖的像素,下面举例说明。
- def tensor_erode(self,bin_img, ksize=3): # 已测试
- #先为原图加入 padding,防止腐蚀后图像尺寸缩小
- B, C, H, W = bin_img.shape
- pad = (ksize - 1) // 2
- bin_img = F.pad(bin_img, [pad, pad, pad, pad], mode='constant', value=0)
- # 将原图 unfold 成 patch
- patches = bin_img.unfold(dimension=2, size=ksize, step=1)
- patches = patches.unfold(dimension=3, size=ksize, step=1)
- # B x C x H x W x k x k
- # 取每个 patch 中最小的值
- eroded, _ = patches.reshape(B, C, H, W, -1).min(dim=-1)
- return eroded
下面是膨胀,按照原理将min改为max即可,这个膨胀的功能暂时还没有测试,上面那个测试是可以使用的。
- def tensor_dilate(self,bin_img, ksize=3): #
- # 首先为原图加入 padding,防止图像尺寸缩小
- B, C, H, W = bin_img.shape
- pad = (ksize - 1) // 2
- bin_img = F.pad(bin_img, [pad, pad, pad, pad], mode='constant', value=0)
- # 将原图 unfold 成 patch
- patches = bin_img.unfold(dimension=2, size=ksize, step=1)
- patches = patches.unfold(dimension=3, size=ksize, step=1)
- # B x C x H x W x k x k
- # 取每个 patch 中最小的值,i.e., 0
- dilate, _ = patches.reshape(B, C, H, W, -1).max(dim=-1)
- return dilate