squeeze和unsqueeze是torch张量常用的降维与升维的一种方式,但这种方式只能增添或减少大小为1的维度,如下:
x1 = torch.randn(1, 8, 256, 256)
x1 = torch.squeeze(x1,dim=0)
print(x1.shape) # torch.Size([8, 256, 256])
x2 = torch.randn(8, 1, 256, 256)
x2 = torch.squeeze(x2,dim=1)
print(x2.shape) # torch.Size([8, 256, 256])
x1 = torch.randn(8, 256, 256)
x1 = torch.unsqueeze(x1,dim=0)
print(x1.shape) # torch.Size([1, 8, 256, 256])
x2 = torch.randn(8, 256, 256)
x2 = torch.unsqueeze(x2,dim=1)
print(x2.shape) # torch.Size([8, 1, 256, 256])
但如果维度大小不为1,squeeze就无效了。
降维:可以使用torch.mean()函数来对维度X进行求平均值,相当于将维度X的所有通道合并为一个单一的通道。
升维:可以使用expand()函数对需要的尺寸进行扩展(其他维度传递-1作为参数,表示在那个维度不进行扩展)。
x1 = torch.randn(2, 8, 256, 256)
x1 = torch.mean(x1, dim=0)
print(x1.shape) # torch.Size([8, 256, 256])
x2 = torch.randn(8, 3, 256, 256)
x2 = torch.mean(x2, dim=1)
print(x2.shape) # torch.Size([8, 256, 256])
x3 = torch.randn(8, 256, 256)
x3 = x3.unsqueeze(0).expand(4,-1,-1,-1)
print(x3.shape) # torch.Size([4, 8, 256, 256])
x4 = torch.randn(16, 256, 256)
x4 = x4.unsqueeze(1).expand(-1, 8, -1, -1)
print(x4.shape) # torch.Size([16, 8, 256, 256])