Mask-Aware Transformer 大空洞修复。
1、图像修复 Introduction
定义
图像修复(Image inpainting、Image completion、image hole-filling)指的是合成图像中缺失区域的过程,可以帮助恢复被遮挡或降质的部分。
在下图中,左图是原图,左图蓝色区域是mask区域(原图的mask区域是不传给模型的),右图是模型输出图。
[图片]
一般输入的是Image图(扣掉后)+Mask图(单通道),如下图。
[图片]
用途
移除物体 remove objects
举例,如上图。
这个任务常用的数据集是Places,包含约400个场景的图像。
生成物体 generate novel objects
举例如下图,生成人脸面部的眼镜、鼻子。
这个任务常用的数据集是CelebA-HQ人脸数据集。
[图片]
[图片]
Tips
GAN模型对于mask区域是进行“移除物体”还是“生成物体”,如果不外加干预(比如一些模型会加入人为互动绘制sketch),那么GAN模型的效果是取决于模型的。下图中,从Places数据集中训练出来的模型和从CelebA-HQ数据集中训练出来的模型对黑色mask区域的填充效果可见,效果取决于模型对数据的学习。
[图片]
难点
图像修复任务的难点大抵有如下:
方法
关键点
要想把图像修复做好,需要着重关注两点:
2、Related Work
Globally and Locally Consistent Image Completion
在关心全局语义的情况下,也注重局部细节。全局判别器网络将整个图像作为输入,而局部判别器网络仅将完成区域周围的小区域作为输入。训练两个判别器网络来确定图像是真实的还是由补全网络完成的,而训练补全网络来欺骗两个判别器网络。
缺点:色彩缺失,需要额外的后处理(快速行军和泊松图像混合)。
[图片]
[图片]
Generative Image Inpainting with Contextual Attention
使用上下文注意力层关注遥远空间位置的特征块。
双阶段图像生成,在Coarse Result之后再次精修。
全局判别器+局部判别器。
[图片]
Image Inpainting for Irregular Holes Using Partial Convolutions
Partial Convolutional Layer,包括一个masked和re-normalized的卷积操作,然后是一个mask-update step。
第一个证明在不规则形状的孔上训练图像绘制模型的有效性的人。
Free-Form Image Inpainting with Gated Convolution
引入门控卷积,为所有层中每个空间位置的每个通道学习动态特征选择机制,显著提高了自由形式掩模和输入的颜色一致性和修复质量。
提出了一种更实用的基于补丁的GAN鉴别器SN-PatchGAN,用于自由形式的图像修复。它简单、快速,并产生高质量的修复结果。
[图片]
Aggregated Contextual Transformations for High-Resolution Image Inpainting
建议学习高分辨率图像在绘画中的aggregated contextual transformations,这允许捕获信息 informative distant contexts 和rich patterns of interest for context reasoning进行上下文推理。
设计了一种新的掩模预测任务来训练适合图像绘制的discriminator。这样的设计迫使discriminator区分真实斑块和合成斑块的详细外观,这反过来又有利于生成器合成细粒度纹理。
[图片]
LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions
使用具有图像宽的感受野的快速傅里叶卷积fast Fourier convolutions (FFCs);
高感受野感知损失 (a high receptive field perceptual loss);
大面积mask训练。
[图片]
Co-ModGAN:Large Scale Image Completion via Co-Modulated Generative Adversarial Networks
提出Co-ModGAN,弥合了图像条件生成结构和最近的无条件调制生成结构之间的差距;
提出了新的P-IDS/U-IDS,用于对GAN的感知保真度进行稳健评估;
3、MAT:Mask-Aware Transformer
CVPR 2022 Best Paper Finalist, Oral
创新点
Transformer Body
本文对transformer模块进行了改进,一是删除了层归一化,二是采用融合学习(使用特征拼接)代替残差学习。
删除层归一化的原因:在大面积区域缺失的情况下,大部分的token是无效的,而层归一化会放大这些无效的token,从而导致训练不稳定;
替换残差连接的原因:残差连接鼓励模型学习高频内容,然而在刚开始大多数的token是无效的,在训练过程中没有适当的低频基础,很难直接学习高频细节,如果使用残差连接就会使优化变得困难。采用融合学习(使用特征拼接)代替残差学习,如下面的T图。
[图片]
[图片]
[图片]
Multi-Head Contextual Attention
为了处理大量的标记(对于512×512的图像,最多有4096个标记)和给定标记的低保真度(最多90%的标记是无用的),我们的注意力模块采用了位移窗口[36]和动态遮罩,能够利用少量可行的标记进行非局部交互。
注意力模块利用移位窗口和动态掩码,只使用有效的token进行加权求和。MCA输出是有效标记的加权和,如下图:
[图片]
Mask Updating Strategy
更新规则:只要当前窗口有一个token是有效的,经过注意力后,该窗口中的所有token都会更新为有效的。如果一个窗口中的所有token都是无效的,经过注意力后,它们仍然无效。
[图片]
Style Manipulation Module
设计了一个风格操作模块,使MAT具有多元化的生成。它通过在重建生成过程中使用额外的噪声输入改变卷积层的权值归一化来操纵输出。为了增强噪声输入的表示能力,我们强制图像条件样式sc从图像特征X和噪声无条件样式su中学习。
B是随机给的mask,由su和sc得到风格表达的s。
s将会改变权重W,从而让模型可以使用随机噪声作为输入,让模型可以有多元化生成。
[图片]
成绩
[图片]
4、图像修复 损失函数
重建损失 Reconstruction Loss
GAN(生成对抗网络)中的重建损失通常用于度量生成器生成的图像与真实图像之间的差异,帮助生成器学习生成更逼真的图像。在 GAN 中,生成器试图生成与真实图像相似的样本,而判别器则评估生成器生成的样本是否足够逼真。重建损失通常使用生成器生成的图像与对应的真实图像之间的差异来衡量。在实际应用中,可以根据任务和需求选择适当的损失函数,如 L1 损失、结构相似性损失(SSIM)等,下面用均方误差(MSE)作为重建损失举例。此外还有一些难以解释理解的重建损失:https://www.zhihu.com/question/521284760/answer/2384076383
import torch
import torch.nn as nn
generated_image = torch.rand((16, 3, 64, 64)) # 16张3通道64x64的随机生成图像
real_image = torch.rand((16, 3, 64, 64)) # 16张3通道64x64的随机真实图像
reconstruction_loss = nn.MSELoss() # 使用均方误差损失
loss_value = reconstruction_loss(generated_image, real_image)
print(“重建损失值:”, loss_value.item())
对抗性损失 Adversarial Loss
在生成对抗网络(GAN)中,对抗性损失是用来训练判别器(Discriminator)和生成器(Generator)之间竞争的损失函数。它鼓励生成器生成逼真的样本,同时使判别器能够区分生成的样本和真实样本。对抗性损失通常是使用交叉熵损失函数来衡量生成样本被正确分类为真实样本的程度。
import torch
import torch.nn as nn
discriminator_predictions = torch.rand((16, 1)) # 判别器对16个样本的预测结果
generated_labels = torch.zeros((16, 1))
adversarial_loss = nn.BCEWithLogitsLoss() # 使用二进制交叉熵损失
loss_value = adversarial_loss(discriminator_predictions, generated_labels)
print(“对抗性损失值:”, loss_value.item())
在上述示例中,我们使用了带有 logits 的二进制交叉熵损失(BCEWithLogitsLoss),将判别器的预测与生成样本的标签进行比较。对抗性损失的目标是使判别器能够正确区分生成样本和真实样本,同时促使生成器生成逼真的样本,从而使两者之间形成平衡竞争关系。
感知损失 Perceived Loss
在图像修复中,感知损失是一种用于训练生成对抗网络(GAN)的损失函数,它帮助网络学习更好地合成逼真的修复图像。感知损失通过比较生成图像和真实图像之间的特征表示来量化生成图像的质量。
下面是一个用PyTorch演示感知损失在图像修复中的示例代码片段:
import torch
import torch.nn as nn
import torchvision.models as models
class PerceptualLoss(nn.Module):
def init(self):
super(PerceptualLoss, self).init()
self.vgg = models.vgg19(pretrained=True).features
self.layers = {
‘3’: ‘relu1_2’, # Conv3_2 -> ReLU1_2
‘8’: ‘relu2_2’, # Conv8_2 -> ReLU2_2
‘17’: ‘relu3_3’, # Conv17_3 -> ReLU3_3
‘26’: ‘relu4_3’ # Conv26_3 -> ReLU4_3
}
for param in self.vgg.parameters():
param.requires_grad = False
def forward(self, x, y):
x_features = self.get_features(x)
y_features = self.get_features(y)
loss = 0
for layer_name in self.layers:
loss += nn.functional.mse_loss(x_features[layer_name], y_features[layer_name])
return loss
def get_features(self, x):
features = {}
prev_x = x
for name, layer in self.vgg._modules.items():
x = layer(x)
if name in self.layers:
features[self.layers[name]] = x
if name == '26':
break
return features
criterion = PerceptualLoss()
fake_image = torch.randn(1, 3, 256, 256) # 生成的修复图像
real_image = torch.randn(1, 3, 256, 256) # 真实图像
loss = criterion(fake_image, real_image)
print(“Perceptual Loss:”, loss.item())
在这个示例中,PerceptualLoss 类从预训练的VGG19模型中提取了不同层的特征,并计算修复图像和真实图像之间的感知损失。这有助于生成对抗网络学习将合成图像的特征与真实图像的特征匹配,从而提高修复图像的质量。
风格损失 Style Loss
风格损失是一种用于训练生成对抗网络(GAN)的损失函数,它有助于确保修复图像在视觉上与原始图像在风格上保持一致。风格损失通过比较生成图像与原始图像之间的特定风格特征,如纹理、颜色和形状等,来量化生成图像的风格相似性。
以下是一个使用PyTorch编写的示例程序,演示如何计算图像修复中的风格损失:
import torch
import torch.nn as nn
import torchvision.models as models
class StyleLoss(nn.Module):
def init(self):
super(StyleLoss, self).init()
self.vgg = models.vgg19(pretrained=True).features
self.layers = {
‘3’: ‘relu1_2’, # Conv3_2 -> ReLU1_2
‘8’: ‘relu2_2’, # Conv8_2 -> ReLU2_2
‘17’: ‘relu3_3’, # Conv17_3 -> ReLU3_3
‘26’: ‘relu4_3’ # Conv26_3 -> ReLU4_3
}
for param in self.vgg.parameters():
param.requires_grad = False
def forward(self, x, y):
x_features = self.get_features(x)
y_features = self.get_features(y)
loss = 0
for layer_name in self.layers:
loss += nn.functional.mse_loss(self.gram_matrix(x_features[layer_name]),
self.gram_matrix(y_features[layer_name]))
return loss
def gram_matrix(self, input):
b, c, h, w = input.size()
features = input.view(b, c, h * w)
gram = torch.bmm(features, features.transpose(1, 2))
gram = gram / (c * h * w)
return gram
def get_features(self, x):
features = {}
prev_x = x
for name, layer in self.vgg._modules.items():
x = layer(x)
if name in self.layers:
features[self.layers[name]] = x
if name == '26':
break
return features
criterion = StyleLoss()
fake_image = torch.randn(1, 3, 256, 256) # 生成的修复图像
original_image = torch.randn(1, 3, 256, 256) # 原始图像
loss = criterion(fake_image, original_image)
print(“Style Loss:”, loss.item())
在这个示例中,StyleLoss 类从预训练的VGG19模型中提取了不同层的特征,并计算修复图像与原始图像之间的风格损失。风格损失有助于生成对抗网络学习将修复图像的风格与原始图像的风格保持一致,从而提高修复图像的视觉品质。
感知损失和风格损失都是用于训练生成模型的损失函数,但它们分别强调了内容和风格两个不同的方面。
5、评价指标 Evaluation Metrics