官方文档
首先Pytorch中grid_sample函数的接口声明如下:
torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)
这里的input和output就是输入的图片,或者是网络中的feature map。关键的处理过程在于grid,grid的最后一维的大小为2,即表示input中pixel的位置信息 (x,y) ,这里一般会将x和y的取值范围归一化到 [−1,1] 之间, (−1,−1) 表示input左上角的像素的坐标,(1,1) 表示input右下角的像素的坐标,对于超出这个范围的坐标(x,y),函数将会根据参数_padding_mode_的设定进行不同的处理。
对于mode='bilinear’参数,则定义了在input中指定位置的pixel value中进行插值的方法,为什么需要插值呢?因为前面我们说了,grid中表示的位置信息x和y的取值范围在 [−1,1] 之间,这就意味着我们要根据一个浮点型的坐标值在input中对pixel value进行采样,mode有nearest和bilinear两种模式。
双线性插值:
举例:
import torch
from torch.nn import functional as F
inp = torch.ones(1, 128, 4, 4)
# 目的是得到一个 长宽为20的tensor
out_h = 20
out_w = 20
grid_x, grid_y = torch.meshgrid(
torch.linspace(-1, 1, out_h),
torch.linspace(-1, 1, out_w)
)
# grid 最后一维度表示在input采样的位置(x,y),y表示图像纵轴,x表示横轴,grid顺序应该先x递增,后y递增
grid = torch.stack((grid_y, grid_x), dim=-1).unsqueeze(0) # (out_h, out_w, 2)
# F.grid_sample -> input:(N,C,Hin,Win), grid:(N,Hout,Wout,2), output:(N,C,Hout,Wout)
# outp = F.grid_sample(features, grid, align_corners=True, mode='bilinear')
outp = F.grid_sample(inp, grid, align_corners=True, mode='nearest')
print(outp.shape) # torch.Size([1, 128, 20, 20])
对图像,特征进行采样用以上grid才不会图像位置错误