For the individual concepts, see the following link: https://blog.csdn.net/weixin_43135178/category_11543123.html
Here we mainly discuss the differences between them.
As shown in the figure above, an autoencoder (AE) consists of two parts: an encoder and a decoder. The encoder and decoder can be viewed as two functions: one maps a high-dimensional input (e.g., an image) to a low-dimensional code, and the other maps the low-dimensional code back to a high-dimensional output (e.g., a generated image). These two functions can in principle take any form, but in deep learning we learn them with neural networks.
At this point we could take only the decoder, feed it a randomly generated code, and obtain a generated image. In practice, however, the results are poor (the reason is explained below), so AEs are mostly used for data compression, while for data generation the VAE introduced below works better.
As described above, the AE's encoder maps an image to a numeric code, and the decoder maps that code back to an image. The problem is that as training keeps reducing the error between the input and output images, the model overfits and generalizes poorly. In other words, a trained AE encodes a given image to one specific code and decodes a given code to one specific image; if a latent code comes from an image the model has never seen, the generated image will not be good. An example:
Suppose our trained AE encodes a "crescent moon" image to code = 1 (assume the code is one-dimensional) and decodes it back to a "crescent moon" image, and encodes a "full moon" to code = 10 and decodes it back to a "full moon". If we now give the AE code = 5, we would hope to get a "half moon"; but since no "half moon" image was encoded during training (or some non-moon image happened to be encoded to 5), we are unlikely to get a "half moon". Therefore AEs are mostly used for data compression and reconstruction, and they are not ideal for data generation.
- import torch
- from torch import nn
- from torch.utils.data import DataLoader
- from torchvision import datasets, transforms
-
-
- # Define the encoder and decoder networks
- class Encoder(nn.Module):
-     def __init__(self, input_dim, latent_dim):
-         super(Encoder, self).__init__()
-         self.fc1 = nn.Linear(input_dim, latent_dim)
-
-     def forward(self, x):
-         x = torch.relu(self.fc1(x))
-         return x
-
-
- class Decoder(nn.Module):
-     def __init__(self, latent_dim, output_dim):
-         super(Decoder, self).__init__()
-         self.fc1 = nn.Linear(latent_dim, output_dim)
-
-     def forward(self, x):
-         x = torch.sigmoid(self.fc1(x))
-         return x
-
-
- class Autoencoder(nn.Module):
-     def __init__(self, input_dim, latent_dim):
-         super(Autoencoder, self).__init__()
-         self.encoder = Encoder(input_dim, latent_dim)
-         self.decoder = Decoder(latent_dim, input_dim)
-
-     def forward(self, x):
-         x = self.encoder(x)
-         x = self.decoder(x)
-         return x
-
-
- # Define the model, loss function, and optimizer
- input_dim = 784  # For MNIST images (28 x 28, flattened)
- latent_dim = 32
- num_epochs = 10
- model = Autoencoder(input_dim, latent_dim)
- criterion = nn.BCELoss()
- optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
-
- # MNIST data loader (assumed here; the original snippet left `dataloader` and `num_epochs` undefined)
- dataset = datasets.MNIST(root="data", train=True, download=True, transform=transforms.ToTensor())
- dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
-
- # Training loop
- for epoch in range(num_epochs):
-     for data in dataloader:
-         img, _ = data
-         img = img.view(img.size(0), -1)  # flatten to (batch_size, 784)
-
-         # Forward pass
-         output = model(img)
-         loss = criterion(output, img)
-
-         # Backward pass and optimization
-         optimizer.zero_grad()
-         loss.backward()
-         optimizer.step()
-
-     print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
-
- # Evaluate the model (reconstruction)
- model.eval()
- with torch.no_grad():
-     for data in dataloader:
-         img, _ = data
-         img = img.view(img.size(0), -1)
-         output = model(img)
-
- # "Generation": feed a random latent vector to the trained decoder
- # (this is exactly the use that, as explained above, works poorly for a plain AE)
- latent_vector = torch.randn(latent_dim)
- with torch.no_grad():
-     generated_image = model.decoder(latent_vector)
Now we change the approach: instead of mapping an image to a single numeric code, we map it to a distribution. Back to the example: we map the "crescent moon" image to a normal distribution with μ = 1, which is like adding noise around 1, so that not only 1 represents "crescent moon", values near 1 do as well, with 1 being the most crescent-like. We map "full moon" to a normal distribution with μ = 10, so values around 10 all represent "full moon". Then code = 5 carries features of both "crescent moon" and "full moon", and decoding it will most likely give a "half moon". This is the idea behind the VAE.
The VAE differs from a plain autoencoder in that the encoder's output is not a single latent vector handed directly to the decoder; instead, a random number or random vector is sampled from a continuous distribution (typically a Gaussian), and the decoder then decodes that sample.
Why constrain the latent space to a chosen distribution (VAE) instead of keeping the unconstrained continuous latent space of the AE? Because:
The regularity of a plain AE's latent space is a hard problem: it depends on the distribution of the data in the original space, the dimensionality of the latent space, the encoder architecture, and so on. So we basically cannot assume that the distribution of the AE's latent vectors is the same distribution we use to generate random codes, and that is awkward: suppose all latent vectors lie between 0 and 1, and we feed the decoder a random vector containing many negative values; the decoder cannot decode anything sensible, because it never saw such inputs during training. It is natural that the distribution of the latent vectors is hard to know, because nothing in a plain autoencoder's training objective forces the latent space to be regular (whereas the VAE assumes the latent vector follows a Gaussian distribution): the autoencoder is trained only to encode and decode with as little loss as possible, and does not care what distribution the latent vectors follow. Naturally, then, we cannot draw inputs from a predefined random distribution and expect the decoder to decode something meaningful.
Since we do not know what distribution the latent vectors follow, why not simply constrain them to a predefined distribution, and make that predefined distribution the same one we use to generate random samples? That solves the problem.
So the VAE outputs a mean and a variance, and then uses the reparameterization trick to sample from the resulting distribution, giving a latent vector that follows that distribution.
Concretely, without reparameterization the model would sample directly from the parameterized distribution (e.g., a normal distribution parameterized by mean μ and variance σ²), and gradients cannot flow back through the sampling step. The reparameterization trick introduces an external noise source that does not depend on the model parameters (usually drawn from a standard normal distribution) and transforms that noise using the model parameters (the mean and the variance) to produce a sample from the target distribution. The model's stochastic output then becomes a deterministic function of the model parameters combined with random noise, so gradients can be back-propagated.
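Written out, the sampled latent variable is z = μ + σ ⊙ ε with ε ~ N(0, I). Since ∂z/∂μ = 1 and ∂z/∂σ = ε, the loss gradient with respect to μ and σ is well defined, while all the randomness is isolated in ε, which needs no gradient.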
Overall architecture: the VAE loss consists of the following two parts:
Reconstruction loss: measures the difference between the input data and the reconstructed data.
KL divergence loss (Kullback-Leibler divergence): measures the difference between the distribution of the learned latent representation and the prior distribution (usually assumed to be a standard normal). KL divergence measures how similar two probability distributions are; by minimizing it, the VAE keeps the learned latent distribution as close as possible to the prior. This helps generation because it constrains the structure of the latent space, making it more regular and easier to sample from and do inference with.
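For a diagonal Gaussian posterior N(μ, σ²) and a standard normal prior N(0, I), this KL term has a closed form, which is exactly what the vae_loss_function below computes:
KL( N(μ, σ²) || N(0, I) ) = −0.5 · Σ_j ( 1 + log σ_j² − μ_j² − σ_j² )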
image --> mean μ + standard deviation σ (in the code below, the encoder outputs the mean and the log variance)
- import torch
- from torch import nn
- from torch.nn import functional as F
- from torch.utils.data import DataLoader
- from torchvision import datasets, transforms
-
- # Encoder class definition
- class Encoder(nn.Module):
-     def __init__(self, input_dim, hidden_dim, latent_dim):
-         super(Encoder, self).__init__()
-         # Fully connected layer: input_dim --> hidden_dim
-         self.fc1 = nn.Linear(input_dim, hidden_dim)
-         # Two fully connected layers to produce mean and log variance
-         # These represent the latent space distribution parameters
-         self.fc21 = nn.Linear(hidden_dim, latent_dim)  # hidden_dim --> mean μ
-         self.fc22 = nn.Linear(hidden_dim, latent_dim)  # hidden_dim --> log variance log(σ²)
-
-     def forward(self, x):
-         # ReLU non-linearity to increase the expressive power of the network
-         h1 = F.relu(self.fc1(x))
-         # Return the mean and log variance for the latent space
-         return self.fc21(h1), self.fc22(h1)
-
-
- # Decoder class definition
- class Decoder(nn.Module):
-     def __init__(self, latent_dim, hidden_dim, output_dim):
-         super(Decoder, self).__init__()
-         # latent_dim --> hidden_dim
-         self.fc3 = nn.Linear(latent_dim, hidden_dim)
-         # hidden_dim --> output_dim (the reconstructed image)
-         self.fc4 = nn.Linear(hidden_dim, output_dim)
-
-     def forward(self, z):
-         h3 = F.relu(self.fc3(z))
-         return torch.sigmoid(self.fc4(h3))
Reparameterization trick:
The math behind the reparameterize code below is
z = μ + σ ⋅ ϵ
where μ is mu, the mean output by the encoder; σ is the standard deviation, computed as torch.exp(0.5 * logvar) (logvar is the log variance log(σ²), hence σ = exp(0.5 · log(σ²))); and ϵ is random noise sampled with torch.randn_like(std).
The expression eps.mul(std).add_(mu) in the code implements this formula: the random noise ϵ is first multiplied by the standard deviation σ, and the mean μ is then added. The resulting z carries the characteristics of the learned distribution (through μ and σ) while also introducing the necessary randomness (through ϵ), which lets the model generate diverse data by sampling.
- # VAE class definition
- # Encode the input --> reparameterize --> decode
- class VAE(nn.Module):
-     def __init__(self, input_dim, hidden_dim, latent_dim):
-         super(VAE, self).__init__()
-         self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
-         self.decoder = Decoder(latent_dim, hidden_dim, input_dim)
-
-     def reparameterize(self, mu, logvar):
-         # Reparameterization trick to sample from the distribution represented by the mean and log variance
-         std = torch.exp(0.5 * logvar)
-         eps = torch.randn_like(std)
-         return eps.mul(std).add_(mu)
-
-     def forward(self, x):
-         mu, logvar = self.encoder(x.view(-1, input_dim))
-         z = self.reparameterize(mu, logvar)
-         return self.decoder(z), mu, logvar
-
-
- # Loss function for VAE
- def vae_loss_function(recon_x, x, mu, logvar):
-     # Binary cross entropy between the target and the output (reconstruction loss)
-     BCE = F.binary_cross_entropy(recon_x, x.view(-1, input_dim), reduction='sum')
-     # KL divergence loss: learned latent distribution <--> prior (standard normal) distribution
-     KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
-     return BCE + KLD
-
-
- # Hyperparameters
- input_dim = 784  # Assuming input is a flattened 28x28 image (e.g., from MNIST)
- hidden_dim = 400
- latent_dim = 20
- epochs = 10
- learning_rate = 1e-3
-
- # Initialize VAE
- vae = VAE(input_dim, hidden_dim, latent_dim)
- optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)
-
- # MNIST data loader (assumed here; the original snippet left `data_loader` undefined)
- dataset = datasets.MNIST(root="data", train=True, download=True, transform=transforms.ToTensor())
- data_loader = DataLoader(dataset, batch_size=128, shuffle=True)
-
- # Training loop
- for epoch in range(epochs):
-     vae.train()  # Set the model to training mode
-     train_loss = 0
-     for batch_idx, (data, _) in enumerate(data_loader):
-         optimizer.zero_grad()  # Zero the gradients
-         recon_batch, mu, logvar = vae(data)  # Forward pass through VAE
-         loss = vae_loss_function(recon_batch, data, mu, logvar)  # Compute the loss
-         loss.backward()  # Backpropagate the loss
-         train_loss += loss.item()
-         optimizer.step()
-     print(f'Epoch [{epoch + 1}/{epochs}], Loss: {train_loss / len(dataset):.4f}')
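After training, new images can be generated in the way a plain AE cannot reliably do: sample latent vectors from the prior (a standard normal) and decode them. A minimal sketch using the vae trained above:
- # Generation: sample z ~ N(0, I) from the prior and decode it
- vae.eval()
- with torch.no_grad():
-     z = torch.randn(16, latent_dim)        # 16 random latent vectors
-     samples = vae.decoder(z)               # shape (16, 784), values in [0, 1]
-     samples = samples.view(-1, 1, 28, 28)  # reshape back into image form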
The name VQ (vector quantization) comes from mapping points in the continuous latent space onto the nearest vector in a discrete set (the vectors in the codebook).
VQ-VAE stands for Vector Quantized-Variational AutoEncoder. It is a deep learning model that combines the variational autoencoder (VAE) with vector quantization (VQ) and is mainly used to learn latent representations of data efficiently. VQ-VAE improves on the traditional VAE by discretizing the continuous latent space. Vector quantization essentially maps each point of the continuous latent space onto the nearest vector of a discrete set (the codebook vectors), which helps the model capture and represent richer, more complex data distributions. Because a codebook is maintained, the range of codes is more controllable. Compared with the VAE (where every dimension of the latent variable z is a continuous value, whereas the defining feature of VQ-VAE is that every dimension of z is a discrete integer), VQ-VAE can generate larger, higher-resolution images, which also paved the way for DALL-E and VQGAN. [The original paper describes this as avoiding the "posterior collapse" problem.]
In addition, because the nearest-neighbour search uses argmin to find the index in the codebook, the operation is non-differentiable. VQ-VAE avoids this with a stop-gradient operation: the gradient at the decoder's input is copied to the encoder's output [the red line in the figure].
Why a discrete latent space helps:
1. Discrete representations are usually better suited to capturing the categorical nature of data, such as different kinds of objects, or different patterns in speech and text.
2. A discrete latent representation helps the model produce sharper outputs. With a continuous latent space, the model may generate blurry samples, especially in some regions of the space; a discrete latent space reduces this blurriness and improves sample quality.
3. Better interpretability: compared with a continuous latent space, each discrete latent variable can be given a clearer semantic meaning. For example, in image tasks, different discrete latent variables may correspond to different visual features or object categories, which makes the model's behaviour and the learned representation easier to understand and explain.
4. Alleviating over-smoothing of the latent space: VAEs sometimes suffer from an "over-smoothed" latent space, where the transitions between different regions are too smooth and the generated samples lack diversity or distinctiveness (prone to mode collapse). By introducing a discrete latent space, VQ-VAE alleviates this problem, because a discrete space naturally separates different regions.
1) Build the codebook and perform VQ
The key to discretizing z is VQ, i.e. vector quantization.
Simply put, we first need a codebook. This codebook is an embedding table, whose weights are then initialized from a uniform distribution:
- self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
- self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
2) Find the embedding in the codebook closest to the image embedding (flat_input)
We look through the codebook for the embedding that is closest to the vector (e.g., smallest Euclidean distance) and use that embedding's index to represent the vector.
distances[i, j] = Σ_d ( flat_input[i, d] − self._embedding.weight[j, d] )², with flat_input of shape (16384, 64) and self._embedding.weight of shape (512, 64), so distances has shape (16384, 512).
- # Calculate the distances between Z_e(x) and the codebook vectors e
- # Squared Euclidean distance is used here
- distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
-              + torch.sum(self._embedding.weight ** 2, dim=1)
-              - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
-
- # Encoding: index of the nearest codebook vector for each input vector
- encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
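Once the nearest indices are known, quantization is just a lookup into the codebook. The full code below does this with a one-hot matrix multiplied against the embedding table; an equivalent, more direct formulation (a sketch, assuming input_shape is the BHWC shape of the encoder output, as in the full forward pass below) would be:
- # Equivalent direct lookup (sketch): index the embedding table with the nearest indices
- quantized = self._embedding(encoding_indices.squeeze(1)).view(input_shape)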
3) How is the non-differentiability (blocked gradient) problem solved?
Because the nearest-neighbour search uses argmin to find the index in the codebook, the operation is non-differentiable. VQ-VAE works around this with a stop-gradient operation: the gradient at the decoder's input (quantize) is copied onto the encoder's output (input).
- quantize = input + (quantize - input).detach()
- # The forward pass is unchanged
-
- # In the backward pass the detach()-ed part has zero gradient, so quantize and input receive the same gradient,
- # i.e. the gradient at quantize is copied onto input
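A tiny self-contained check of this straight-through trick (toy tensors, not part of the original code): the forward value equals the quantized tensor, while the gradient flows back to the encoder output as if quantization were the identity.
- # Toy check of the straight-through estimator (sketch)
- inp = torch.randn(4, requires_grad=True)   # stands in for the encoder output
- q = torch.round(inp)                       # stands in for the non-differentiable quantization
- q_st = inp + (q - inp).detach()            # forward value == q, gradient flows to inp
- q_st.sum().backward()
- print(inp.grad)                            # tensor of ones: the gradient passed straight through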
The biggest difference between VQ-VAE and VAE is that it directly finds a discrete value for each attribute, through a table-lookup-like operation on the codebook.
Writing it simply as below, without the (quantized - inputs).detach(), does not work: the decoder would then just receive the encoder's continuous output directly, the codebook would never come into play, and the model would degenerate into an ordinary AE.
quantize = input
Differences from the VAE loss: the VAE's KL loss is dropped, and two new loss terms are added.
1) Reconstruction loss: the difference between the input and the reconstruction, as before.
2) Codebook loss: this term updates the codebook vectors so that they better represent the continuous latent representations produced by the encoder.
3) Commitment loss: this term keeps the encoder's output from drifting too far from the codebook vector it selected, which stabilizes training. It measures the distance between the encoder's continuous latent representation and the quantized representation (the selected codebook vector), keeps the encoder output consistent with the codebook choice, and is also usually computed with mean squared error (MSE).
Let Ze(x) denote the encoder's continuous latent representation of input x, and e the selected (nearest) codebook vector.
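In formula form (sg[·] is the stop-gradient operator), the total VQ-VAE loss implemented in the code below is
L = ||x − x̂||² + ||sg[Ze(x)] − e||² + β · ||Ze(x) − sg[e]||²
where the first term is the reconstruction loss, the second is the codebook loss (q_latent_loss: gradients only reach the codebook), the third is the commitment loss (e_latent_loss: gradients only reach the encoder), and β is commitment_cost.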
We must make sure that the number of channels D of the encoder's output matches the dimension of the vectors in the codebook (after the encoder, the "channel" count is generally no longer 3; it may be larger, e.g. 64; "channels" is only an informal name here).
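In the full code below this matching is done by a 1x1 convolution (_pre_vq_conv) that projects the encoder output to embedding_dim channels. A sketch of the resulting shape flow into the quantizer, using the example shapes from the docstring in the code:
- # Shape flow into the quantizer (sketch, using the example shapes from the docstring below)
- # encoder output after _pre_vq_conv: (B, D, H, W) = (16, 64, 32, 32), with D == embedding_dim
- # permute BCHW -> BHWC:              (16, 32, 32, 64)
- # flatten:                           (16*32*32, 64) = (16384, 64); each 64-dim vector is quantized independently
- # distances to the codebook (512, 64) have shape (16384, 512); argmin over dim=1 picks the nearest code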
The implementation follows the steps above; the complete code is as follows:
- import numpy as np
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.utils.data import DataLoader
- import torch.optim as optim
-
- import torchvision.datasets as datasets
- import torchvision.transforms as transforms
-
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
- # ===================================================== Load Data =====================================================
- training_data = datasets.CIFAR10(root="data", train=True, download=True,
- transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
- ]))
-
- validation_data = datasets.CIFAR10(root="data", train=False, download=True,
- transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
- ]))
- # Compute the variance of the training data (used later to scale the reconstruction loss)
- data_variance = np.var(training_data.data / 255.0)
-
-
-
- # ===================================================== Vector Quantizer Layer =====================================================
- """
- This layer takes a tensor to be quantized.
- The channel dimension will be used as the space in which to quantize. All other dimensions will be flattened and will be seen as different examples to quantize.
- The output tensor will have the same shape as the input.
- As an example for a BCHW tensor of shape [16, 64, 32, 32], we will first convert it to an BHWC tensor of shape [16, 32, 32, 64] and then reshape it into [16384, 64] and all 16384 vectors of size 64 will be quantized independently.
- In other words, the channels are used as the space in which to quantize. All other dimensions are flattened and seen as different examples to quantize, 16384 in this case.
- """
-
- class VectorQuantizer(nn.Module):
- def __init__(self, num_embeddings, embedding_dim, commitment_cost):
- '''
- :param num_embeddings: size of the codebook
- :param embedding_dim: dimension of each vector in the codebook
- :param commitment_cost: the β weight of the commitment loss
- '''
-
- super(VectorQuantizer, self).__init__()
-
- self._embedding_dim = embedding_dim
- self._num_embeddings = num_embeddings
- # Build the codebook and initialize its weights from a uniform distribution
- self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
- self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
- self._commitment_cost = commitment_cost
-
- def forward(self, inputs):
-
- # convert inputs(encoder's output) from BCHW -> BHWC
- inputs = inputs.permute(0, 2, 3, 1).contiguous()
- input_shape = inputs.shape
-
- # Flatten input
- flat_input = inputs.view(-1, self._embedding_dim)
-
- # Calculate the Z_e(x) and e distances
- # Squared Euclidean distance is used here; distances has shape (16384, 512)
- distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
- + torch.sum(self._embedding.weight ** 2, dim=1)
- - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
-
- # Use the distances to get the index of the nearest codebook vector
- encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
- # Convert the indices to one-hot format
- encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
- encodings.scatter_(1, encoding_indices, 1)
-
- # Quantize and unflatten: look up the nearest embedding vectors
- quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
-
- # Loss
- # commit loss
- e_latent_loss = F.mse_loss(quantized.detach(), inputs)
- # codebook loss
- q_latent_loss = F.mse_loss(quantized, inputs.detach())
- loss = q_latent_loss + self._commitment_cost * e_latent_loss
- # Straight-through trick (gradient copying): adding a term that is constant w.r.t. gradients keeps encoder and decoder differentiable end to end
- quantized = inputs + (quantized - inputs).detach()
- # Monitor codebook usage with perplexity: higher perplexity means higher entropy, i.e. the codes are used more uniformly
- avg_probs = torch.mean(encodings, dim=0)
- perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
-
- # convert quantized from BHWC -> BCHW
- return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
-
-
- """
- We will also implement a slightly modified version which will use exponential moving averages to update the embedding vectors instead of an auxiliary loss.
- This has the advantage that the embedding updates are independent of the choice of optimizer for the encoder, decoder and other parts of the architecture.
- For most experiments the EMA version trains faster than the non-EMA version.
- """
-
- class VectorQuantizerEMA(nn.Module):
- def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
- super(VectorQuantizerEMA, self).__init__()
-
- self._embedding_dim = embedding_dim
- self._num_embeddings = num_embeddings
-
- self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
- self._embedding.weight.data.normal_()
- self._commitment_cost = commitment_cost
-
- self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
- self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
- self._ema_w.data.normal_()
-
- self._decay = decay
- self._epsilon = epsilon
-
- def forward(self, inputs):
- # convert inputs from BCHW -> BHWC
- inputs = inputs.permute(0, 2, 3, 1).contiguous()
- input_shape = inputs.shape
-
- # Flatten input
- flat_input = inputs.view(-1, self._embedding_dim)
-
- # Calculate distances
- distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
- + torch.sum(self._embedding.weight ** 2, dim=1)
- - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
-
- # Encoding
- encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
- encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
- encodings.scatter_(1, encoding_indices, 1)
-
- # Quantize and unflatten
- quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
-
- # Use EMA to update the embedding vectors
- if self.training:
- self._ema_cluster_size = self._ema_cluster_size * self._decay + \
- (1 - self._decay) * torch.sum(encodings, 0)
-
- # Laplace smoothing of the cluster size
- n = torch.sum(self._ema_cluster_size.data)
- self._ema_cluster_size = (
- (self._ema_cluster_size + self._epsilon)
- / (n + self._num_embeddings * self._epsilon) * n)
-
- dw = torch.matmul(encodings.t(), flat_input)
- self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
-
- self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
-
- # Loss
- e_latent_loss = F.mse_loss(quantized.detach(), inputs)
- loss = self._commitment_cost * e_latent_loss
-
- # Straight Through Estimator
- quantized = inputs + (quantized - inputs).detach()
- avg_probs = torch.mean(encodings, dim=0)
- perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
-
- # convert quantized from BHWC -> BCHW
- return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
-
-
- # ===================================================== Encoder & Decoder Architecture =====================================================
- class Residual(nn.Module):
- def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
- super(Residual, self).__init__()
- self._block = nn.Sequential(
- nn.ReLU(True),
- nn.Conv2d(in_channels=in_channels,
- out_channels=num_residual_hiddens,
- kernel_size=3, stride=1, padding=1, bias=False),
- nn.ReLU(True),
- nn.Conv2d(in_channels=num_residual_hiddens,
- out_channels=num_hiddens,
- kernel_size=1, stride=1, bias=False)
- )
-
- def forward(self, x):
- return x + self._block(x)
-
-
- class ResidualStack(nn.Module):
- def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
- super(ResidualStack, self).__init__()
- self._num_residual_layers = num_residual_layers
- self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
- for _ in range(self._num_residual_layers)])
-
- def forward(self, x):
- for i in range(self._num_residual_layers):
- x = self._layers[i](x)
- return F.relu(x)
-
-
- class Encoder(nn.Module):
- def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
- super(Encoder, self).__init__()
-
- self._conv_1 = nn.Conv2d(in_channels=in_channels,
- out_channels=num_hiddens // 2,
- kernel_size=4,
- stride=2, padding=1)
- self._conv_2 = nn.Conv2d(in_channels=num_hiddens // 2,
- out_channels=num_hiddens,
- kernel_size=4,
- stride=2, padding=1)
- self._conv_3 = nn.Conv2d(in_channels=num_hiddens,
- out_channels=num_hiddens,
- kernel_size=3,
- stride=1, padding=1)
- self._residual_stack = ResidualStack(in_channels=num_hiddens,
- num_hiddens=num_hiddens,
- num_residual_layers=num_residual_layers,
- num_residual_hiddens=num_residual_hiddens)
-
- def forward(self, inputs):
- x = self._conv_1(inputs)
- x = F.relu(x)
-
- x = self._conv_2(x)
- x = F.relu(x)
-
- x = self._conv_3(x)
- return self._residual_stack(x)
-
-
- class Decoder(nn.Module):
- def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
- super(Decoder, self).__init__()
-
- self._conv_1 = nn.Conv2d(in_channels=in_channels,
- out_channels=num_hiddens,
- kernel_size=3,
- stride=1, padding=1)
-
- self._residual_stack = ResidualStack(in_channels=num_hiddens,
- num_hiddens=num_hiddens,
- num_residual_layers=num_residual_layers,
- num_residual_hiddens=num_residual_hiddens)
-
- self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
- out_channels=num_hiddens // 2,
- kernel_size=4,
- stride=2, padding=1)
-
- self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens // 2,
- out_channels=3,
- kernel_size=4,
- stride=2, padding=1)
-
- def forward(self, inputs):
- x = self._conv_1(inputs)
-
- x = self._residual_stack(x)
-
- x = self._conv_trans_1(x)
- x = F.relu(x)
-
- return self._conv_trans_2(x)
-
-
- # ===================================================== Train =====================================================
- batch_size = 256
- num_training_updates = 15000
-
- num_hiddens = 128
- num_residual_hiddens = 32
- num_residual_layers = 2
-
- # dimension of each codebook vector (embedding_dim) and codebook size (num_embeddings)
- embedding_dim = 64
- num_embeddings = 512
-
- commitment_cost = 0.25
-
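- # decay > 0 selects the EMA codebook update (VectorQuantizerEMA); the final value below (0) selects the plain VectorQuantizer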
- decay = 0.99
- decay = 0
- learning_rate = 1e-3
-
- training_loader = DataLoader(training_data,
- batch_size=batch_size,
- shuffle=True,
- pin_memory=True)
- validation_loader = DataLoader(validation_data,
- batch_size=32,
- shuffle=True,
- pin_memory=True)
-
-
- class Model(nn.Module):
- def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens, num_embeddings, embedding_dim, commitment_cost, decay=0):
- super(Model, self).__init__()
-
- self._encoder = Encoder(3, num_hiddens, num_residual_layers, num_residual_hiddens)
- # Post-process the encoder output with a 1x1 conv so its channel dimension matches the embedding table (codebook) dimension
- self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens,
- out_channels=embedding_dim,
- kernel_size=1,
- stride=1)
- if decay > 0.0:
- self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim, commitment_cost, decay)
- else:
- self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
-
- self._decoder = Decoder(embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
-
- def forward(self, x):
- z = self._encoder(x)
- z = self._pre_vq_conv(z)
- loss, quantized, perplexity, _ = self._vq_vae(z)
- x_recon = self._decoder(quantized)
-
- return loss, x_recon, perplexity
-
-
- model = Model(num_hiddens, num_residual_layers, num_residual_hiddens, num_embeddings, embedding_dim, commitment_cost, decay).to(device)
- optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=False)
- model.train()
- train_res_recon_error = []
- train_res_perplexity = []
-
- for i in range(num_training_updates):
- (data, _) = next(iter(training_loader))  # draw one (shuffled) batch per update
- data = data.to(device)
- optimizer.zero_grad()
-
- vq_loss, data_recon, perplexity = model(data)
- recon_error = F.mse_loss(data_recon, data) / data_variance
- loss = recon_error + vq_loss
- loss.backward()
-
- optimizer.step()
-
- train_res_recon_error.append(recon_error.item())
- train_res_perplexity.append(perplexity.item())
-
- if (i + 1) % 100 == 0:
- print('%d iterations' % (i + 1))
- print('recon_error: %.3f' % np.mean(train_res_recon_error[-100:]))
- print('perplexity: %.3f' % np.mean(train_res_perplexity[-100:]))
- print()
-
Source: pytorch-vq-vae/vq-vae.ipynb at master · zalandoresearch/pytorch-vq-vae · GitHub
So AutoEncoder, VAE, and VQ-VAE can be unified by how the distribution of the latent code is designed: the AutoEncoder lets the network learn an arbitrary distribution, the VAE constrains it to a normal distribution, and VQ-VAE constrains it to the discrete distribution defined by the codebook. In short, the AutoEncoder's reconstruction idea is to express a high-dimensional data distribution with a low-dimensional latent code distribution, while VAE and VQ-VAE control the image-generation process by designing the form of that latent code distribution.