• VQGAN理论加代码一对一详解,小白向解析


    最近在看图像生成相关论文,记录一下学习内容。感觉只看论文有点干巴,所以理论代码一对一上。

    整体网络框架

    VQGAN (Vector Quantized Generative Adversarial Network) 是一种基于 GAN 的生成模型,可以将图像或文本转换为高质量的图像。

    • VQ (Vector Quantization)是一种数据压缩技术,是指将连续数据表示为离散化的向量。输入的图像或文本被映射到 VQ 空间中的离散化向量表示,然后,离散化向量然后被送到 GAN 模型中进行图像生成。(参见上图的下半部分)在训练过程中,VQGAN 模型会优化两个损失函数:一个用于量化误差(即离散化向量和连续值之间的误差),另一个用于生成器和判别器之间的对抗损失。
    • GAN 是由生成器和判别器两个模型组成的,生成器负责生成图像,判别器负责判断生成的图像是否为真实的图像。在训练过程中,生成器和判别器相互博弈,不断优化各自的参数,以使生成的图像更接近真实图像。

    在这里插入图片描述
    上图是论文的总体模型图。下面具体来看看如何实现的。

    训练过程

    VQGAN整体模型需要两步训练。

    • 第一步通过自监督学习训练CNN Encoder,CNN Decoder,和Codebook;
    • 第二步在已训练好的CNN Encoder和Codebook基础上,通过将code随机替换加入强噪声,用Transformer去重建其code组,来提高Transformer的泛化能力。

    第一步——CNN Encoder,CNN Decoder,Codebook

    如上图所示,从一张输入图片开始(一般是RGB图片) x ∈ R H × W × 3 x \in \mathbb{R}^{H\times W×3} xRH×W×3,其通过CNN Encoder编码后得到中间特征变量 z ^ ∈ R h × w × n z \hat z \in \mathbb{R}^{h\times w×n_z} z^Rh×w×nz。这时再引入一个codebook,注意,如果是普通的AutoEncoder,则会将 z ^ \hat z z^ 直接送入解码器中进行图像重建。而在VQVAE/VQGAN中,会将 z ^ \hat z z^进行进一步离散化编码成 z q ∈ R h × w × n z z_q\in \mathbb{R}^{h\times w×n_z} zqRh×w×nz

    具体做法为:预先生成一个离散数值的codebook Z = { z k } k = 1 K , z k ∈ R n z \mathcal Z=\{z_k\}_{k=1}^{K},z_k \in \mathbb{R}^{n_z} Z={zk}k=1K,zkRnz,在 z ^ \hat z z^ 的每一个编码位置都去 Z \mathcal Z Z中去寻找其距离最近的code,生成具有相同维度的变量。特别注意,这里 z ^ , z q \hat z,z_q z^,zq Z \mathcal Z Z中的单个编码特征的维度都为 n z n_z nz。这一步离散编码的过程就叫做“quantization”, 也就是上面的那个公式。

    这样一来,就可以在已经数值离散化的 z q z_q zq基础上使用CNN Decoder进行解码:
    x ^ = G ( z q ) = G ( q ( E ( x ) ) ) \hat x=G(z_q)=G(q(E(x))) x^=G(zq)=G(q(E(x)))

    整个过程的自监督损失如下:
    L V Q ( E , G , Z ) = ∣ ∣ x − x ^ ∣ ∣ 2 + ∣ ∣ s g [ E ( x ) ] − z q ∣ ∣ 2 + ∣ ∣ s g ( z q ) − E ( x ) ∣ ∣ 2 \mathcal L_{VQ}(E,G,Z)=||x-\hat x||^2+||sg[E(x)]-z_q||^2+||sg(z_q)-E(x)||^2 LVQ(E,G,Z)=∣∣xx^2+∣∣sg[E(x)]zq2+∣∣sg(zq)E(x)2其中,上式中的第一项 L r e c \mathcal L_{rec} Lrec 为重建损失(reconstruction loss) s g [ ⋅ ] sg[·] sg[] 为梯度终止操作(stop-gradient operation),其目的在于保证神经网络梯度可以正常回传,而不受离散编码的影响。因此在codebook的搭建过程中,我们看到由 z ^ \hat z z^得到 z q z_q zq之后,先计算出公式中后两项损失,然后又增加了一步detach操作。

    loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)
    z_q = z + (z_q - z).detach()
    
    • 1
    • 2

    这么一来,在其后面计算 L r e c \mathcal L_{rec} Lrec,即公式的第一项中, z q z_q zq的梯度可以顺利复制到 z ^ \hat z z^上,而不受离散编码过程的干扰。除了这个重建过程使用的自监督损失外,还加入了GAN中的对抗loss。文章里没有具体写对抗loss的类型。通过源码可以发现使用的是hinge loss。对于判别器而言,其损失函数可以笼统地表示为:
    L G A N ( { E , G , Z } , D ) = l o g D ( x ) + l o g ( 1 − D ( x ^ ) ) \mathcal L_{GAN}(\{E,G,\mathcal Z\}, D)=logD(x)+log(1-D(\hat x)) LGAN({E,G,Z},D)=logD(x)+log(1D(x^))

    所以总的误差可以写成:
    L = L V Q + λ L G A N \mathcal L = \mathcal L_{VQ}+\lambda \mathcal L_{GAN} L=LVQ+λLGAN

    总结来说就是:
    x → z ^ → z q → x ^ x\to \hat z\to z_q\to \hat x xz^zqx^
    下面主要来看看这三部分的代码
    CNN Encoder, CNN Decoder是一种基于UNet的代码结构,具体细节可以从原文中获取,这里不在细说

    CNN Encoder

    class Encoder(nn.Module):
        def __init__(self, args):
            super(Encoder, self).__init__()
            channels = [128, 128, 128, 256, 256, 512]
            attn_resolutions = [16]
            num_res_blocks = 2
            resolution = 256
            layers = [nn.Conv2d(args.image_channels, channels[0], 3, 1, 1)]
            for i in range(len(channels)-1):
                in_channels = channels[i]
                out_channels = channels[i + 1]
                for j in range(num_res_blocks):
                    layers.append(ResidualBlock(in_channels, out_channels))
                    in_channels = out_channels
                    if resolution in attn_resolutions:
                        layers.append(NonLocalBlock(in_channels))
                if i != len(channels)-2:
                    layers.append(DownSampleBlock(channels[i+1]))
                    resolution //= 2
            layers.append(ResidualBlock(channels[-1], channels[-1]))
            layers.append(NonLocalBlock(channels[-1]))
            layers.append(ResidualBlock(channels[-1], channels[-1]))
            layers.append(GroupNorm(channels[-1]))
            layers.append(Swish())
            layers.append(nn.Conv2d(channels[-1], args.latent_dim, 3, 1, 1))
            self.model = nn.Sequential(*layers)
            
        def forward(self, x):
            return self.model(x)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29

    具体的模块定义可以阅读源代码,这个都不难理解。

    CNN Decoder

    class Decoder(nn.Module):
        def __init__(self, args):
            super(Decoder, self).__init__()
            channels = [512, 256, 256, 128, 128]
            attn_resolutions = [16]
            num_res_blocks = 3
            resolution = 16
    
            in_channels = channels[0]
            layers = [nn.Conv2d(args.latent_dim, in_channels, 3, 1, 1),
                      ResidualBlock(in_channels, in_channels),
                      NonLocalBlock(in_channels),
                      ResidualBlock(in_channels, in_channels)]
    
            for i in range(len(channels)):
                out_channels = channels[i]
                for j in range(num_res_blocks):
                    layers.append(ResidualBlock(in_channels, out_channels))
                    in_channels = out_channels
                    if resolution in attn_resolutions:
                        layers.append(NonLocalBlock(in_channels))
                if i != 0:
                    layers.append(UpSampleBlock(in_channels))
                    resolution *= 2
                   
            layers.append(GroupNorm(in_channels))
            layers.append(Swish())
            layers.append(nn.Conv2d(in_channels, args.image_channels, 3, 1, 1))
            self.model = nn.Sequential(*layers)
    
        def forward(self, x):
            return self.model(x)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32

    Codebook

    我最开始看的时候,最不明白的地方就是这个codebook,一直在想,这兄弟是哪蹦出来的。其实就是另外定义的一个网络,说白了甚至算不上一个网络就是一个nn.Embedding(),还是之前没看VQVAE的锅。

    class Codebook(nn.Module):
        def __init__(self, args):
            super(Codebook, self).__init__()
            self.num_codebook_vectors = args.num_codebook_vectors
            self.latent_dim = args.latent_dim
            self.beta = args.beta
            self.embedding = nn.Embedding(self.num_codebook_vectors, self.latent_dim)
            self.embedding.weight.data.uniform_(-1.0 / self.num_codebook_vectors, 1.0 / self.num_codebook_vectors)
    
        def forward(self, z):
            z = z.permute(0, 2, 3, 1).contiguous()
            z_flattened = z.view(-1, self.latent_dim)
            d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
                torch.sum(self.embedding.weight**2, dim=1) - \
                2*(torch.matmul(z_flattened, self.embedding.weight.t()))
    
            min_encoding_indices = torch.argmin(d, dim=1)
            z_q = self.embedding(min_encoding_indices).view(z.shape)
            loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)
            z_q = z + (z_q - z).detach()
            z_q = z_q.permute(0, 3, 1, 2)
    
            return z_q, min_encoding_indices, loss
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23

    第二步——Transformer 训练

    经VQGAN得到的压缩图像与真实图像有一个本质性的不同:真实图像的像素值具有连续性,相邻的颜色更加相似,而压缩图像的像素值则没有这种连续性。
    压缩图像的这一特性让寻找一个压缩图像生成模型变得异常困难。多数强大的真实图像生成模型(比如GAN)都是输出一个连续的浮点颜色值,再做一个浮点转整数的操作,得到最终的像素值。而对于压缩图像来说,这种输出连续颜色的模型都不适用了。而恰好,Transformer天生就支持建模离散的输出。在NLP中,每个单词都可以用一个离散的数字表示。Transformer会不断生成表示单词的数字,以达到生成句子的效果。
    VQGAN的作者使用了自回归图像生成模型的常用做法,给图像的每个像素从左到右,从上到下规定一个顺序。有了先后顺序后,图像就可以被视为一个一维句子,可以用Transfomer生成句子的方式来生成图像了。在第i 步,Transformer会根据前i−1 个像素 s < i s_{s<i生成第 i i i 个像素 s i s_i si.

    在这里插入图片描述

    来看具体实现——训练过程

    现在进入第二步,这篇论文毕竟是个图像生成的任务,注意之前的三个零件已经训练好不动了,现在我们需要得到一组排列好的code,送进CNN Decoder中来实现图像生成。那么这组code怎么来的?这就是Transformer发挥作用的地方了。该工作使用的Transformer模型为著名的GPT-2。迁移到VQGAN中,即可理解为先预测一个code,再一步步地通过已经预测好的code去推断下一个code。

    code都是从训练好的codebook Z \mathcal Z Z中寻找,就像写文章一样,你有词典了,现在你要从词典中一个字一个字的写成一篇新文章

    为了训练Transformer,

    • 将输入图片 x ∈ R H × W × 3 x \in \mathbb{R}^{H\times W×3} xRH×W×3,通过CNN Encoder编码后得到中间特征变量 z ^ ∈ R h × w × n z \hat z \in \mathbb{R}^{h\times w×n_z} z^Rh×w×nz,再将 z ^ \hat z z^进行进一步离散化编码成 z q ∈ R h × w × n z z_q\in \mathbb{R}^{h\times w×n_z} zqRh×w×nz,[注意这部分都是用上一步训练好的模型,这里只做前传,不做梯度回传训练],
    • z q z_q zq 被展平到空间 R h w × n z \mathbb{R}^{hw×n_z} Rhw×nz ,这样得到了 h w hw hw 个排列好的维度为 n z n_z nz的code。
    • 随机将其中的一部分code替换为随机生成的相同维度的向量,输入transformer模型,也即是给特征中加入噪声。接着进行训练,训练损失函数为cross-entropy交叉熵损失。

    假设被替换后的code组合的索引为modified_indices,原本 z q z_q zq的code索引为unmodified_indices,那么Transformer的学习过程即为:喂入modified_indices,通过训练学习重构出unmodified_indices。
    L t r a n s f o r m e r = E x ∼ p ( x ) [ − l o g p ( s ) ] \mathcal L_{transformer}=\mathbb E_{x\sim p(x)}[-logp(s)] Ltransformer=Exp(x)[logp(s)]

    代码具体实现如下:

    """
    首先得到由x前传得到的unmodified_indices
    """
    
    sos_tokens = torch.ones(x.shape[0], 1) * self.sos_token 
    # (B, 1), sos_token是一个整数,表示从第几个token开始预测,一般为0
    
    mask = torch.bernoulli(self.pkeep * torch.ones(unmodified_indices.shape, device=unmodified_indices.device)) 
    # (B, h*w), 元素都为0和1,0的是mask掉的元素,1是保留的元素(比例为pkeep)
    
    mask = mask.round().to(dtype=torch.int64)
    random_indices = torch.randint_like(indices, self.transformer.config.vocab_size)    
    # (B, h*w), 生成一些任意的indices,用来填充被遮挡的部分
    modified_indices= mask * unmodified_indices+ (1 - mask) * random_indices 
    # (B, h*w), mask为1(未遮挡)部分仍然保留原始indices,mask为0(遮挡)部分用random_indices填充
    modified_indices= torch.cat((sos_tokens, modified_indices), dim=1) 
    # (B, h*w+1),将0放到第一个indice前面
    targets = unmodified_indices
    logits, _ = self.transformer(modified_indices[:, :-1]) 
    # logits: (B, h*w, num_codebook_vectors), 意思是h*w个indices处,预测出来的对应每一个codebook_vector的概率
    
    
    """
    然后再由logits和targets之间计算交叉熵损失
    """
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26

    注意这是训练的过程,不是生成的过程。在VQGAN无条件生成图片的过程中,没有任何先验条件,CNN Encoder直接被弃用。我们需要得到一组排列好的code,送进CNN Decoder中来实现图像生成。

  • 相关阅读:
    LAMP架构-nginx并发优化
    LCD1602指定位置显示字符串-详细版
    3、在docker 容器中安装tomcat
    C++DAY47
    (完美方案)解决mfc140u.dll文件丢失问题,快速且有效的修复
    使用jenkins自动化部署
    QT打开网页或者资源管理器:QDesktopServices以及QSettings 用法
    JDBC工具类
    Chap12.1圆通荣达,进退自如
    zabbix mysql监控项
  • 原文地址:https://blog.csdn.net/qq_42208244/article/details/132889927