• AE(自动编码器)与VAE(变分自动编码器)的区别和联系?


    他们各自的概念看以下链接就可以了:https://blog.csdn.net/weixin_43135178/category_11543123.html

     这里主要谈一下他们的区别?


    先说结论:

    • VAE是AE的升级版,VAE也可以被看作是一种特殊的AE
    • AE主要用于数据的压缩与还原,VAE主要用于生成
    • AE是将数据映直接映射为数值code(确定的数值),而VAE是先将数据映射为分布,再从分布中采样得到数值code。
    • 损失函数和优化目标不同


    一、AE(Auto Encoder, 自动编码器

    1、AE的结构

    如上图所示,自动编码器主要由两部分组成:编码器(Encoder)和解码器(Decoder)。编码器和解码器可以看作是两个函数,一个用于将高维输入(如图片)映射为低维编码(code),另一个用于将低维编码(code)映射为高维输出(如生成的图片)。这两个函数可以是任意形式,但在深度学习中,我们用神经网络去学习这两个函数。

    这时候我们只要拿出Decoder部分,随机生成一个code然后输入,就可以得到一张生成的图像。但实际上这样的生成效果并不好(下面解释原因),因此AE多用于数据压缩,而数据生成则使用下面所介绍的VAE更好。

    2、AE的缺陷

    由上面介绍可以看出,AE的Encoder是将图片映射成“数值编码”,Decoder是将“数值编码”映射成图片。这样存在的问题是,在训练过程中,随着不断降低输入图片与输出图片之间的误差,模型会过拟合,泛化性能不好。也就是说对于一个训练好的AE,输入某个图片,就只会将其编码为某个确定的code,输入某个确定的code就只会输出某个确定的图片,如果这个latent code来自于没见过的图片,那么生成的图片也不会好。下面举个例子来说明:

    假设我们训练好的AE将“新月”图片encode成code=1(这里假设code只有1维),将其decode能得到“新月”的图片;将“满月”encode成code=10,同样将其decode能得到“满月”图片。这时候如果我们给AE一个code=5,我们希望是能得到“半月”的图片,但由于之前训练时并没有将“半月”的图片编码,或者将一张非月亮的图片编码为5,那么我们就不太可能得到“半月”的图片。因此AE多用于数据的压缩和恢复,用于数据生成时效果并不理想。

    3、AE的代码实现

    3.1)AE encoder + decoder + AE的模型

    1. import torch
    2. from torch import nn
    3. from torch.autograd import Variable
    4. # Define the encoder and decoder networks
    5. class Encoder(nn.Module):
    6. def __init__(self, input_dim, latent_dim):
    7. super(Encoder, self).__init__()
    8. self.fc1 = nn.Linear(input_dim, latent_dim)
    9. def forward(self, x):
    10. x = torch.relu(self.fc1(x))
    11. return x
    12. class Decoder(nn.Module):
    13. def __init__(self, latent_dim, output_dim):
    14. super(Decoder, self).__init__()
    15. self.fc1 = nn.Linear(latent_dim, output_dim)
    16. def forward(self, x):
    17. x = torch.sigmoid(self.fc1(x))
    18. return x
    19. class Autoencoder(nn.Module):
    20. def __init__(self, input_dim, latent_dim):
    21. super(Autoencoder, self).__init__()
    22. self.encoder = Encoder(input_dim, latent_dim)
    23. self.decoder = Decoder(latent_dim, input_dim)
    24. def forward(self, x):
    25. x = self.encoder(x)
    26. x = self.decoder(x)
    27. return x

    3.2)AE模型训练过程

    1. # Define the model, loss function, and optimizer
    2. input_dim = 784 # For MNIST images
    3. latent_dim = 32
    4. model = Autoencoder(input_dim, latent_dim)
    5. criterion = nn.BCELoss()
    6. optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    7. # Training loop
    8. for epoch in range(num_epochs):
    9. for data in dataloader:
    10. img, _ = data
    11. img = img.view(img.size(0), -1)
    12. img = Variable(img)
    13. # Forward pass
    14. output = model(img)
    15. loss = criterion(output, img)
    16. # Backward pass and optimization
    17. optimizer.zero_grad()
    18. loss.backward()
    19. optimizer.step()
    20. print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

    3.3)AE模型的推理过程

    1. # Evaluate the model
    2. model.eval()
    3. with torch.no_grad():
    4. for data in dataloader:
    5. img, _ = data
    6. img = img.view(img.size(0), -1)
    7. img = Variable(img)
    8. output = model(img)

    3.4)AE怎么通过latent code生成新的图像

    1. # Initialize the autoencoder
    2. input_dim = 784 # Assuming input is a flattened 28x28 image
    3. latent_dim = 20
    4. autoencoder = Autoencoder(input_dim, latent_dim)
    5. latent_vector = torch.randn(latent_dim)
    6. # Generate a new image by passing the latent vector through the decoder
    7. with torch.no_grad():
    8. generated_image = autoencoder.decoder(latent_vector)

    4、如何解决AE的问题呢?

    这时候我们转变思路,不将图片映射成“数值编码”,而将其映射成“分布”。还是刚刚的例子,我们将“新月”图片映射成μ=1的正态分布,那么就相当于在1附近加了噪声,此时不仅1表示“新月”,1附近的数值也表示“新月”,只是1的时候最像“新月”。将"满月"映射成μ=10的正态分布,10的附近也都表示“满月”。那么code=5时,就同时拥有了“新月”和“满月”的特点,那么这时候decode出来的大概率就是“半月”了。这就是VAE的思想。

    二、VAE(Variational Auto-Encoder, 变分自动编码器)

    1、VAE的结构

    vae和常规自编码器不同的地方在于它的encoder的output不是一个latent vector让decoder去做decode,而是从某个连续的分布里(常见的是高斯分布)采样得到一个随机数or随机向量,然后decode再去针对这个scaler做解码。

    不使用连续的AE这种分布,而是使用符合某种分布的VAE,这是因为:

    常规的ae的潜在空间的规律性是一个难点,它取决于初始空间中数据的分布、潜在空间的维度和编码器的架构等。因此,我们基本不可能就认为ae的latent vector的distribution和我们产生随机数的distribution是一个distribution(不可能),那就很尴尬了,假设latent vector的取值范围都在0~1之间,然后我们产生了一个包含了大量的负数的随机数让decoder去decode,那么decoder压根就decode不出什么正常的东西,training阶段压根就没见过嘛。想一想,潜在空间中的latent vector 的分布规律很难知道也是正常的,因为常规的自动编码器的任务中没有任何东西被训练来强制获得这样的规律(但是vae就会假设latent vector服从高斯分布)自动编码器被训练成以尽可能少的损失进行编码和解码,压根就不care latent vector服从什么分布。那么自然,我们是不可能使用一个预定义的随机分布产生随机的input然后又期望decoder能够decode出有意义的东西的.

    既然我们不知道latent vector服从什么分布,我们就直接人为对其进行约束满足某种预定义的分布,这个预定义的分布和我们产生随机数的分布保持一致,不就完美解决问题了吗?

    所以通过VAE求出均值和方差,然后使用重参数化技巧在得到的这个分布中进行采样,就可以得到符合此分布的latent vector了。

    为什么使用重参数化?

    具体来说,在不使用重参数化的情况下,模型会直接从参数化的分布(例如,正态分布,由均值 μ 和方差 σ2 参数化)中采样,这使得梯度无法直接通过采样过程回传。重参数化技巧通过引入一个不依赖于模型参数的外部噪声源(通常是标准正态分布中抽取的),并对这个噪声进行变换(使用模型参数如均值和方差),来生成符合目标分布的样本。这样,模型的随机输出就可以表示为模型参数的确定性函数和一个随机噪声的组合。便可以完成梯度回传

    2、VAE的代码实现

    整体架构,VAE计算以下两方面之间的损失:

    1. 重构损失(Reconstruction Loss):这一部分的损失计算的是输入数据与重构数据之间的差异。

    2. KL散度(Kullback-Leibler Divergence Loss):这一部分的损失衡量的是学习到的潜在表示的分布与先验分布(通常假设为标准正态分布)之间的差异。KL散度是一种衡量两个概率分布相似度的指标,VAE通过最小化KL散度来确保学习到的潜在表示的分布尽可能接近先验分布。这有助于模型生成性能的提升,因为它约束了潜在空间的结构,使其更加规整,便于采样和推断。

    1)Encoder

    image --> 均值 + 标准差

    1. import torch
    2. from torch import nn
    3. from torch.nn import functional as F
    4. # Encoder class definition
    5. class Encoder(nn.Module):
    6. def __init__(self, input_dim, hidden_dim, latent_dim):
    7. super(Encoder, self).__init__()
    8. # 使用FC将输入变为隐藏层hidden_dim
    9. self.fc1 = nn.Linear(input_dim, hidden_dim)
    10. # Two fully connected layers to produce mean and log variance
    11. # These will represent the latent space distribution parameters
    12. self.fc21 = nn.Linear(hidden_dim, latent_dim) # 隐藏层hidden_dim --> 均值Mean μ
    13. self.fc22 = nn.Linear(hidden_dim, latent_dim) # 隐藏层hidden_dim --> 标准差Log variance σ
    14. def forward(self, x):
    15. # 使用RELU非线性变换,增加网络的表达能力
    16. h1 = F.relu(self.fc1(x))
    17. # Return the mean and log variance for the latent space
    18. return self.fc21(h1), self.fc22(h1)

    2)Decoder

    1. # Decoder class definition
    2. class Decoder(nn.Module):
    3. def __init__(self, latent_dim, hidden_dim, output_dim):
    4. super(Decoder, self).__init__()
    5. # latent_dim --> hidden_dim
    6. self.fc3 = nn.Linear(latent_dim, hidden_dim)
    7. # hidden_dim --> output_dim(输出的图像)
    8. self.fc4 = nn.Linear(hidden_dim, output_dim)
    9. def forward(self, z):
    10. h3 = F.relu(self.fc3(z))
    11. return torch.sigmoid(self.fc4(h3))

    3)VAE

    重参数化技巧:

    这段代码对应的数学公式可以写作:

    z= σ⋅ϵ + μ

    其中:

    • z 是从潜在分布中采样得到的样本。
    • log var(log方差):  log(σ2)
    • μ 是潜在分布的均值,对应代码中的 mu
    • σ 是潜在分布的标准差 std ,通过 torch.exp(0.5*logvar) 计算得到,这里 logvar 是对数方差log(σ2),因此 σ=exp(0.5*​log(σ2))。
    • ϵ 是从标准正态分布 N(0,1) 中采样得到的随机噪声,对应 torch.randn_like(std)

    代码中的 eps.mul(std).add_(mu) 实现了上述公式的计算,即首先将随机噪声 ϵ 与标准差 σ 相乘,然后将结果加上均值 μ。这样,得到的 z 既包含了模型学习到的分布的特征(通过 μ 和 σ),同时也引入了必要的随机性(通过 ϵ),允许模型通过采样生成多样化的数据。

    1. # VAE class definition
    2. # Encode the input --> reparameterize --> decode
    3. class VAE(nn.Module):
    4. def __init__(self, input_dim, hidden_dim, latent_dim):
    5. super(VAE, self).__init__()
    6. self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
    7. self.decoder = Decoder(latent_dim, hidden_dim, input_dim)
    8. def reparameterize(self, mu, logvar):
    9. # Reparameterization trick to sample from the distribution represented by the mean and log variance
    10. std = torch.exp(0.5*logvar)
    11. eps = torch.randn_like(std)
    12. return eps.mul(std).add_(mu)
    13. def forward(self, x):
    14. mu, logvar = self.encoder(x.view(-1, input_dim))
    15. z = self.reparameterize(mu, logvar)
    16. return self.decoder(z), mu, logvar

    4)Loss

    1. # Loss function for VAE
    2. def vae_loss_function(recon_x, x, mu, logvar):
    3. # Binary cross entropy between the target and the output
    4. BCE = F.binary_cross_entropy(recon_x, x.view(-1, input_dim), reduction='sum')
    5. # KL divergence loss : 学习到的潜在表示的分布 <--> 先验分布(标准正态分布)
    6. KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    7. return BCE + KLD

    5)训练过程

    1. # Hyperparameters
    2. input_dim = 784 # Assuming input is a flattened 28x28 image (e.g., from MNIST)
    3. hidden_dim = 400
    4. latent_dim = 20
    5. epochs = 10
    6. learning_rate = 1e-3
    7. # Initialize VAE
    8. vae = VAE(input_dim, hidden_dim, latent_dim)
    9. optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)
    10. # Training process function
    11. for epoch in range(epochs):
    12. vae.train() # Set the model to training mode
    13. train_loss = 0
    14. for batch_idx, (data, _) in enumerate(data_loader):
    15. optimizer.zero_grad() # Zero the gradients
    16. recon_batch, mu, logvar = vae(data) # Forward pass through VAE
    17. loss = vae_loss_function(recon_batch, data, mu, logvar) # Compute the loss
    18. loss.backward() # Backpropagate the loss
    19. train_loss += loss.item()
    20. optimizer.step()

    三、VQ-VAE

    之所以叫做VQ(向量量化),主要是因为将连续潜在空间的点映射到最近的一组离散的向量(即码本中的向量)上

    VQ-VAE的全称是Vector Quantized-Variational AutoEncoder,即向量量化变分自编码器。这是一种结合了变分自编码器(VAE)和向量量化(VQ)的深度学习模型,主要用于高效地学习数据的潜在表示。VQ-VAE通过将连续的潜在表示空间离散化来改进传统VAE模型。向量量化的过程实质上是将连续潜在空间的点映射到最近的一组离散的向量(即码本中的向量)上,这有助于模型捕捉和表示更加丰富和复杂的数据分布,由于维护了一个codebook,编码范围更加可控,VQVAE相对于VAE(VAE的隐变量 z 的每一维都是一个连续的值, 而VQ-VAE最大的特点就是, z 的每一维都是离散的整数。),可以生成更大更高清的图片(这也为后续DALLE和VQGAN的出现做了铺垫)。【原文中说的是避免了“后验坍塌”的问题】

    1、算法步骤:

    1. 通过Encoder学习出中间编码 Ze(x)【绿色】
    2. 事先定义好codebook,它有N个e组成【紫色】
    3. 然后通过最邻近搜索与中间编码Ze(x)最相似(接近)的codebook中K个向量之一,并记住这个向量的index【青色】
    4. 根据得到的所有index去映射对应的codebook中的vector,得到输入图像对应的特征表征Zq(x)【紫色】
    5. 然后通过Decoder对Zq(x)进行重建

    另外由于最邻近搜索使用argmax来找codebook中的索引位置,导致不可导问题,VQVAE通过stop gradient操作来避免最邻近搜索的不可导问题,也就是通过stop gradient操作,将decoder输入的梯度复制到encoder的输出上【红色的线】。

    2、一些问题

    A. 为什么要进行向量量化?(为什么要将 z 离散化)?

    1. 离散的表示通常更适合于捕捉数据中的类别性质,如不同种类的对象、语音或文本数据的不同模式等。

    2. 离散的潜在表示有助于模型生成更加清晰的输出。在连续潜在空间中,模型可能在生成新样本时产生模糊的结果,特别是在空间的某些区域中。而离散潜在空间能够降低这种模糊性,提高生成样本的质量。

    3. 增强模型的解释性,相比于连续潜在空间,离散潜在空间可以为每个离散的潜在变量赋予更明确的语义解释。例如,在处理图像的任务中,不同的离散潜在变量可能对应于不同的视觉特征或对象类别,这使得模型的行为和学习到的表示更易于理解和解释。

    5. 缓解潜在空间的过度平滑问题,VAE有时会遇到潜在空间的"过度平滑"问题,即潜在空间中不同区域之间的过渡太平滑,导致生成的样本缺乏多样性或区分度不够(容易模型崩塌)。通过引入离散潜在空间,VQ-VAE可以缓解这个问题,因为离散空间天然具有区分不同区域的能力。

    B. 如何将 z 离散化?

    1)构建codebook进行VQ 

    将 z 离散化的关键就是VQ, 即vector quatization.

    简单来说, 就是要先有一个codebook, 这个codebook是一个embedding table,然后再利用均匀分布对权重初始化

    1. self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
    2. self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)

    2)找到与图像embedding(flat_input)最近的codebook中的embedding 

    我们在这个codebook 中找到和 vector 最接近(比如欧氏距离最近)的一个embedding, 用这个embedding的index来代表这个vector.

    ∑ [flat_input(16384, 64) - self._embedding.weight(512,64)]^2

    1. # Calculate the Z_e(x) and e distances
    2. # 这里使用欧几里得距离的平方求距离
    3. distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
    4. + torch.sum(self._embedding.weight ** 2, dim=1)
    5. - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
    6. # Encoding
    7. encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)

    3)怎么解决的梯度截断(不可导)问题?

    另外由于最邻近搜索使用argmin来找codebook中的索引位置,导致不可导问题,VQVAE通过stop gradient操作来避免最邻近搜索的不可导问题,也就是通过stop gradient操作,将decoder输入(quantize)的梯度复制到encoder的输出上(input)。

    1. quantize = input + (quantize - input).detach()
    2. # 正向传播和往常一样
    3. # 反向传播时,detach()这部分梯度为0,quantize和input的梯度相同
    4. # 即实现将quantize复制给input

    VQVAE相比于VAE最大的不同是,直接找每个属性的离散值,通过类似于查表的

    直接按照下面这样,而不进行 (quantized - inputs).detach() 是不可以的,因为这样会让模型退化到一个VAE的样子,因为你是直接将输出复制给输入的。具体没搞清楚,反正人家是这么干的...

    quantize = input

    3、VQVAE的损失

    与VAE的不同:去掉VAE的KL loss,增加了两项loss

    1)重构损失(reconstruction loss):

    • 目标:衡量重构数据和原始数据之间的相似度。
    • 计算:通常使用均方误差(MSE)或交叉熵损失来计算重构图像和原始图像之间的差异。

    2)代码本损失(codebook loss):

    代码本损失关注于更新码本向量,使其更好地代表输入数据的连续潜在表示。

    假设Ze​(x)是编码器对输入x的连续潜在表示,e是选取的最接近的码本向量。

    3)提交损失(Commitment Loss)

    提交损失则确保编码器的输出不会偏离它选择的码本向量太远,从而保证训练过程的稳定性。

    计算模型编码器输出的连续潜在表示和量化后的表示(即选取的码本向量)之间的距离。这有助于稳定训练,确保编码器输出与码本的选择保持一致,也通常使用均方误差(MSE)计算。

    假设Ze​(x)是编码器对输入x的连续潜在表示,e是选取的最接近的码本向量。

    4、代码示例

    需要确保经过encoder的图像的通道数D(此时经过encoder后的图像“通道数”不一定再是3了,可能会更大,例如64,我们这里只是形象的把它叫做“通道数”罢了)与codebook中的向量维度是相同的。

    整体步骤

    具体实现步骤如下:

    完整代码如下:

    1. from __future__ import print_function
    2. import matplotlib.pyplot as plt
    3. import numpy as np
    4. from scipy.signal import savgol_filter
    5. from six.moves import xrange
    6. import umap
    7. import torch
    8. import torch.nn as nn
    9. import torch.nn.functional as F
    10. from torch.utils.data import DataLoader
    11. import torch.optim as optim
    12. import torchvision.datasets as datasets
    13. import torchvision.transforms as transforms
    14. from torchvision.utils import make_grid
    15. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    16. # ===================================================== Load Data =====================================================
    17. training_data = datasets.CIFAR10(root="data", train=True, download=True,
    18. transform=transforms.Compose([
    19. transforms.ToTensor(),
    20. transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
    21. ]))
    22. validation_data = datasets.CIFAR10(root="data", train=False, download=True,
    23. transform=transforms.Compose([
    24. transforms.ToTensor(),
    25. transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
    26. ]))
    27. # 计算 方差
    28. data_variance = np.var(training_data.data / 255.0)
    29. # ===================================================== Vector Quantizer Layer =====================================================
    30. """
    31. This layer takes a tensor to be quantized.
    32. The channel dimension will be used as the space in which to quantize. All other dimensions will be flattened and will be seen as different examples to quantize.
    33. The output tensor will have the same shape as the input.
    34. As an example for a BCHW tensor of shape [16, 64, 32, 32], we will first convert it to an BHWC tensor of shape [16, 32, 32, 64] and then reshape it into [16384, 64] and all 16384 vectors of size 64 will be quantized independently.
    35. In otherwords, the channels are used as the space in which to quantize. All other dimensions will be flattened and be seen as different examples to quantize, 16384 in this case.
    36. """
    37. class VectorQuantizer(nn.Module):
    38. def __init__(self, num_embeddings, embedding_dim, commitment_cost):
    39. '''
    40. :param num_embeddings: codebook的大小
    41. :param embedding_dim: codebook中每个vector的维度
    42. :param commitment_cost: commit loss的β
    43. '''
    44. super(VectorQuantizer, self).__init__()
    45. self._embedding_dim = embedding_dim
    46. self._num_embeddings = num_embeddings
    47. # 构建一个codebook,用均匀分布对codebook的权重进行初始化
    48. self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
    49. self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
    50. self._commitment_cost = commitment_cost
    51. def forward(self, inputs):
    52. # convert inputs(encoder's output) from BCHW -> BHWC
    53. inputs = inputs.permute(0, 2, 3, 1).contiguous()
    54. input_shape = inputs.shape
    55. # Flatten input
    56. flat_input = inputs.view(-1, self._embedding_dim)
    57. # Calculate the Z_e(x) and e distances
    58. # 这里使用欧几里得距离的平方求距离distance(16384,512)
    59. distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
    60. + torch.sum(self._embedding.weight ** 2, dim=1)
    61. - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
    62. # 通过distance得到距离最近的index
    63. encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
    64. # 转为one-hot格式
    65. encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
    66. encodings.scatter_(1, encoding_indices, 1)
    67. # Quantize and unflatten:得到最近邻的Embedding Vector
    68. quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
    69. # Loss
    70. # commit loss
    71. e_latent_loss = F.mse_loss(quantized.detach(), inputs)
    72. # codebook loss
    73. q_latent_loss = F.mse_loss(quantized, inputs.detach())
    74. loss = q_latent_loss + self._commitment_cost * e_latent_loss
    75. # trick(梯度复制),通过添加一个常数让编码器和解码器连续可导
    76. quantized = inputs + (quantized - inputs).detach()
    77. # 利用困惑度监测分布,困惑度越大,信息熵也就越大,分布就没有这么均匀
    78. avg_probs = torch.mean(encodings, dim=0)
    79. perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
    80. # convert quantized from BHWC -> BCHW
    81. return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
    82. """
    83. We will also implement a slightly modified version which will use exponential moving averages to update the embedding vectors instead of an auxillary loss.
    84. This has the advantage that the embedding updates are independent of the choice of optimizer for the encoder, decoder and other parts of the architecture.
    85. For most experiments the EMA version trains faster than the non-EMA version.
    86. """
    87. class VectorQuantizerEMA(nn.Module):
    88. def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
    89. super(VectorQuantizerEMA, self).__init__()
    90. self._embedding_dim = embedding_dim
    91. self._num_embeddings = num_embeddings
    92. self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
    93. self._embedding.weight.data.normal_()
    94. self._commitment_cost = commitment_cost
    95. self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
    96. self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
    97. self._ema_w.data.normal_()
    98. self._decay = decay
    99. self._epsilon = epsilon
    100. def forward(self, inputs):
    101. # convert inputs from BCHW -> BHWC
    102. inputs = inputs.permute(0, 2, 3, 1).contiguous()
    103. input_shape = inputs.shape
    104. # Flatten input
    105. flat_input = inputs.view(-1, self._embedding_dim)
    106. # Calculate distances
    107. distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
    108. + torch.sum(self._embedding.weight ** 2, dim=1)
    109. - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
    110. # Encoding
    111. encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
    112. encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
    113. encodings.scatter_(1, encoding_indices, 1)
    114. # Quantize and unflatten
    115. quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
    116. # Use EMA to update the embedding vectors
    117. if self.training:
    118. self._ema_cluster_size = self._ema_cluster_size * self._decay + \
    119. (1 - self._decay) * torch.sum(encodings, 0)
    120. # Laplace smoothing of the cluster size
    121. n = torch.sum(self._ema_cluster_size.data)
    122. self._ema_cluster_size = (
    123. (self._ema_cluster_size + self._epsilon)
    124. / (n + self._num_embeddings * self._epsilon) * n)
    125. dw = torch.matmul(encodings.t(), flat_input)
    126. self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
    127. self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
    128. # Loss
    129. e_latent_loss = F.mse_loss(quantized.detach(), inputs)
    130. loss = self._commitment_cost * e_latent_loss
    131. # Straight Through Estimator
    132. quantized = inputs + (quantized - inputs).detach()
    133. avg_probs = torch.mean(encodings, dim=0)
    134. perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
    135. # convert quantized from BHWC -> BCHW
    136. return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
    137. # ===================================================== Encoder & Decoder Architecture =====================================================
    138. class Residual(nn.Module):
    139. def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
    140. super(Residual, self).__init__()
    141. self._block = nn.Sequential(
    142. nn.ReLU(True),
    143. nn.Conv2d(in_channels=in_channels,
    144. out_channels=num_residual_hiddens,
    145. kernel_size=3, stride=1, padding=1, bias=False),
    146. nn.ReLU(True),
    147. nn.Conv2d(in_channels=num_residual_hiddens,
    148. out_channels=num_hiddens,
    149. kernel_size=1, stride=1, bias=False)
    150. )
    151. def forward(self, x):
    152. return x + self._block(x)
    153. class ResidualStack(nn.Module):
    154. def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
    155. super(ResidualStack, self).__init__()
    156. self._num_residual_layers = num_residual_layers
    157. self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
    158. for _ in range(self._num_residual_layers)])
    159. def forward(self, x):
    160. for i in range(self._num_residual_layers):
    161. x = self._layers[i](x)
    162. return F.relu(x)
    163. class Encoder(nn.Module):
    164. def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
    165. super(Encoder, self).__init__()
    166. self._conv_1 = nn.Conv2d(in_channels=in_channels,
    167. out_channels=num_hiddens // 2,
    168. kernel_size=4,
    169. stride=2, padding=1)
    170. self._conv_2 = nn.Conv2d(in_channels=num_hiddens // 2,
    171. out_channels=num_hiddens,
    172. kernel_size=4,
    173. stride=2, padding=1)
    174. self._conv_3 = nn.Conv2d(in_channels=num_hiddens,
    175. out_channels=num_hiddens,
    176. kernel_size=3,
    177. stride=1, padding=1)
    178. self._residual_stack = ResidualStack(in_channels=num_hiddens,
    179. num_hiddens=num_hiddens,
    180. num_residual_layers=num_residual_layers,
    181. num_residual_hiddens=num_residual_hiddens)
    182. def forward(self, inputs):
    183. x = self._conv_1(inputs)
    184. x = F.relu(x)
    185. x = self._conv_2(x)
    186. x = F.relu(x)
    187. x = self._conv_3(x)
    188. return self._residual_stack(x)
    189. class Decoder(nn.Module):
    190. def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
    191. super(Decoder, self).__init__()
    192. self._conv_1 = nn.Conv2d(in_channels=in_channels,
    193. out_channels=num_hiddens,
    194. kernel_size=3,
    195. stride=1, padding=1)
    196. self._residual_stack = ResidualStack(in_channels=num_hiddens,
    197. num_hiddens=num_hiddens,
    198. num_residual_layers=num_residual_layers,
    199. num_residual_hiddens=num_residual_hiddens)
    200. self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
    201. out_channels=num_hiddens // 2,
    202. kernel_size=4,
    203. stride=2, padding=1)
    204. self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens // 2,
    205. out_channels=3,
    206. kernel_size=4,
    207. stride=2, padding=1)
    208. def forward(self, inputs):
    209. x = self._conv_1(inputs)
    210. x = self._residual_stack(x)
    211. x = self._conv_trans_1(x)
    212. x = F.relu(x)
    213. return self._conv_trans_2(x)
    214. # ===================================================== Train =====================================================
    215. batch_size = 256
    216. num_training_updates = 15000
    217. num_hiddens = 128
    218. num_residual_hiddens = 32
    219. num_residual_layers = 2
    220. # codebook的维度
    221. embedding_dim = 64
    222. num_embeddings = 512
    223. commitment_cost = 0.25
    224. decay = 0.99
    225. decay = 0
    226. learning_rate = 1e-3
    227. training_loader = DataLoader(training_data,
    228. batch_size=batch_size,
    229. shuffle=True,
    230. pin_memory=True)
    231. validation_loader = DataLoader(validation_data,
    232. batch_size=32,
    233. shuffle=True,
    234. pin_memory=True)
    235. class Model(nn.Module):
    236. def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens, num_embeddings, embedding_dim, commitment_cost, decay=0):
    237. super(Model, self).__init__()
    238. self._encoder = Encoder(3, num_hiddens, num_residual_layers, num_residual_hiddens)
    239. # 对encoder的输出进行后处理,得到与embedding table一样大小的维度
    240. self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens,
    241. out_channels=embedding_dim,
    242. kernel_size=1,
    243. stride=1)
    244. if decay > 0.0:
    245. self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim, commitment_cost, decay)
    246. else:
    247. self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
    248. self._decoder = Decoder(embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
    249. def forward(self, x):
    250. z = self._encoder(x)
    251. z = self._pre_vq_conv(z)
    252. loss, quantized, perplexity, _ = self._vq_vae(z)
    253. x_recon = self._decoder(quantized)
    254. return loss, x_recon, perplexity
    255. model = Model(num_hiddens, num_residual_layers, num_residual_hiddens, num_embeddings, embedding_dim, commitment_cost, decay).to(device)
    256. optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=False)
    257. model.train()
    258. train_res_recon_error = []
    259. train_res_perplexity = []
    260. for i in xrange(num_training_updates):
    261. (data, _) = next(iter(training_loader))
    262. data = data.to(device)
    263. optimizer.zero_grad()
    264. vq_loss, data_recon, perplexity = model(data)
    265. recon_error = F.mse_loss(data_recon, data) / data_variance
    266. loss = recon_error + vq_loss
    267. loss.backward()
    268. optimizer.step()
    269. train_res_recon_error.append(recon_error.item())
    270. train_res_perplexity.append(perplexity.item())
    271. if (i + 1) % 100 == 0:
    272. print('%d iterations' % (i + 1))
    273. print('recon_error: %.3f' % np.mean(train_res_recon_error[-100:]))
    274. print('perplexity: %.3f' % np.mean(train_res_perplexity[-100:]))
    275. print()

    pytorch-vq-vae/vq-vae.ipynb at master · zalandoresearch/pytorch-vq-vae · GitHub

    四、总结

    • AE主要用于数据的压缩与还原,在生成数据上使用VAE。
    • AE是将数据映直接映射为数值code,而VAE是先将数据映射为分布,再从分布中采样得到数值code。
    • VQ-VAE是将中间编码映射为codebook中K个向量之一,然后通过Decoder对latent code进行重建

    因此AutoEncoder、VAE和VQ-VAE可以统一为latent code的概率分布设计不一样,AutoEncoder通过网络学习得到任意概率分布VAE设计为正态分布VQVAE设计为codebook的离散分布总之,AutoEncoder的重构思想就是用低纬度的latent code分布来表达高纬度的数据分布,VAE和VQVAE的重构思想是通过设计latent code的分布形式,进而控制图片生成的过程。

    漫谈VAE和VQVAE,从连续分布到离散分布 - 知乎

  • 相关阅读:
    数据通讯基础
    Gateway路由的配置方式
    (八)RabbitMQ发布确认
    Servlet | 域对象、request对象其它常用的方法
    3如何搭建组件库的样式工程之button-scss
    如何实现 System.out.println("a") 显示 b
    怎么在国家级媒体网站投稿发布通稿
    Linux目录结构
    如何快速高效压缩图片?
    AWS SAP-C02 考试指南
  • 原文地址:https://blog.csdn.net/weixin_43135178/article/details/130592568