相信大家在学习与图像任务相关的神经网络时,经常会见到这样一个预处理方式。
- self.to_tensor_norm = transforms.Compose([
- transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
- ])
具体原理及作用稍后解释,不知道大家有没有想过,将这样一个经过改变的图像数据输入到网络中,那么输出的结果也是这种类似改动过的,那岂不是真实的数据了。
所以一般会有个后处理的代码,如下:
- def tensor2img(img):
- img = np.round((img.permute(0, 2, 3, 1).cpu().numpy() + 1)* 127.5)
- img = img.clip(min=0, max=255).astype(np.uint8)
- return img
为什么这样就可以将改动过的数据恢复原样了,后处理的代码看着也不像预处理的逆过程啊。
先来分析一下代码,了解其处理过程,最后再推理出这两个互为逆过程。
transforms.ToTensor()
transforms.ToTensor()
是PyTorch中的一个图像转换方法,用于将PIL图像或numpy数组转换为PyTorch张量。具体来说,它会执行以下操作:
下面是我翻译的源码的注释,包含了输入的要求:
torchvision.transforms.ToTensor
类用于将 PIL 图像或 numpy 数组转换为张量。这个转换不支持 torchscript。
将一个 PIL 图像或 numpy 数组(大小为 H x W x C,其中 H 表示高度,W 表示宽度,C 表示通道数)的像素值范围从 [0, 255] 转换为范围在 [0.0, 1.0] 的 torch.FloatTensor,其形状为 (C x H x W)。这种转换只有在以下情况下才会进行:
np.uint8
。(因为uint8的类型的取值范围是0-255)在其他情况下,转换后的张量将不会进行缩放。
两者内容互为补充,相信足够理解这个代码了,如果不够理解,没事,我自己写个代码解释:
上述数值被分别除以255得到转换后的张量,现在应该有更直观的理解了。
transforms.Normalize()
transforms.Normalize()
是PyTorch中的一个图像转换方法,用于对张量进行标准化处理。具体来说,它执行以下操作:
在给定的示例中,(0.5, 0.5, 0.5)
表示每个通道的均值,(0.5, 0.5, 0.5)
表示每个通道的标准差。这个转换将图像的每个通道的像素值从0到1的范围,调整到-1到1的范围内。
上述的预处理的两个步骤可以概括为归一化或者标准化,为什么需要这两个步骤呢,我举例子加以说明
加速收敛:
提高模型性能:
稳定性:
防止过拟合:
适应不同初始化:
节省计算资源:
改善梯度下降的效率:
img = np.round((img.permute(0, 2, 3, 1).cpu().numpy() + 1)* 127.5)
这行代码的作用是将PyTorch张量转换为numpy数组,并执行以下操作:
img.permute(0, 2, 3, 1)
:这一步是对张量的维度进行重新排列,将通道维度移到最后一个维度上。这通常是因为在PyTorch中,图像的通道维度是第二个维度,而在numpy数组中,通常是最后一个维度。所以这一步是为了将数据转换为numpy数组后,通道维度的顺序与numpy数组的约定相匹配。
.cpu().numpy()
:这一步将PyTorch张量移动到CPU上,并将其转换为numpy数组。通常,在GPU上进行计算后,需要将数据移回CPU上才能调用numpy方法。
+ 1
:这一步将数组中的所有元素加1,将范围从[-1, 1]映射到[0, 2]。
* 127.5
:这一步将数组中的所有元素乘以127.5,将范围从[0, 2]映射到[0, 255],将数据重新缩放到uint8范围内。
np.round()
:这一步对数组中的所有元素执行四舍五入操作,将浮点数转换为整数。
综合起来,这行代码的作用是将PyTorch张量(范围在[-1, 1]之间)转换为numpy数组,并将其值重新映射到uint8范围内(0-255),并将浮点数转换为整数。
img = img.clip(min=0, max=255).astype(np.uint8)
这行代码的作用是确保numpy数组中的数值范围在0到255之间,并将其类型转换为无符号8位整数(uint8),以便表示图像像素值。
先把代码放一起进行比较
- 预处理:
- self.to_tensor_norm = transforms.Compose([
- transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
- ])
-
- 后处理:
- def tensor2img(img):
- img = np.round((img.permute(0, 2, 3, 1).cpu().numpy() + 1)* 127.5)
- img = img.clip(min=0, max=255).astype(np.uint8)
- return img
下面是推导过程:
完结撒花!
不足之处还请大家指正。