发布日期:2023/05/18
主页地址:http://myhz0606.com/article/ddpm
1 从直觉上理解DDPM
在详细推到公式之前,我们先从直觉上理解一下什么是扩散
对于常规的生成模型,如GAN,VAE,它直接从噪声数据生成图像,我们不妨记噪声数据为
对于常规的生成模型:
学习一个解码函数(即我们需要学习的模型)p,实现
常规方法只需要一次预测即能实现噪声到目标的映射,虽然速度快,但是效果不稳定。
常规生成模型的训练过程(以VAE为例)
对于diffusion model
它将噪声到目标的过程进行了多步拆解。不妨假设一共有
对于DDPM它采用的是一种自回归式的重建方法,每次的输入是当前的时刻及当前时刻的噪声图片。也就是说它把噪声到目标图片的生成分成了T步,这样每一次的预测相当于是对残差的预测。优势是重建效果稳定,但速度较慢。
训练整体pipeline包含两个过程
2 diffusion pipeline
2.1前置知识:
高斯分布的一些性质
(1)如果
(2)如果
均值为
2.2 加噪过程
1 前向过程:将图片数据映射为噪声
每一个时刻都要添加高斯噪声,后一个时刻都是由前一个时刻加噪声得到。(其实每一个时刻加的噪声就是训练所用的标签)。即
下面我们详细来看
记
同理:
将式(8)代入式(7)可得:
由于
我们不妨记
通过递推,容易得到
其中
⚠️加噪过程是确定的,没有模型的介入. 其目的是制作训练时标签
2.3 去噪过程
给定
去噪过程是加噪过程的逆向.如果说加噪过程是求给定初始分布
如何求
⚠️这里的(去噪)
有了式(17)还是一头雾水,
延续加噪过程的推导
结合高斯分布的定义(6)来看式(22),不难发现
⚠️式17做了一个近似
可以求得:
通过上式,我们可得
该式是真实的条件分布.我们目标是让模型学到的条件分布
(这个结论也可以通过最小化上述两个分布的KL散度推得)
下面的问题变为:如何构造
我们注意到
DDPM作者表示直接从
代入到式(24)中
经过这次化简,我们将
此时对齐均值的问题就转化成:给定
2.3.1 训练与采样过程
训练的目标就是这所有时刻两个噪声的差异的期望越小越好(用MSE或L1-loss).
下图为论文提供的训练和采样过程

2.3.2 采样过程
通过以上讨论,我们推导出
式(32)和论文给出的采样递推公式一致.
至此,已完成DDPM整体的pipeline.
还没想明白的点,为什么不能根据(7)的变形来进行采样计算呢?
3 从代码理解训练&预测过程
3.1 训练过程
参考代码仓库: https://github.com/lucidrains/denoising-diffusion-pytorch/tree/main/denoising_diffusion_pytorch
已知项: 我们假定有一批N张图片
第一步: 随机采样K组成batch,如
第二步: 随机采样一些时间步
t = torch.randint(0, self.num_timesteps, (b,), device=device).long() # 随机采样时间步
第三步: 随机采样噪声
noise = default(noise, lambda: torch.randn_like(x_start)) # 基于高斯分布采样噪声
第四步: 计算
def linear_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
betas = linear_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def q_sample(x_start, t, noise=None):
"""
\begin{eqnarray}
x_t &=& \sqrt{\alpha_t}x_{t-1} + \sqrt{(1 - \alpha_t)}z_t \nonumber \\
&=& \sqrt{\alpha_t}x_{t-1} + \sqrt{\beta_t}z_t
\end{eqnarray}
"""
return (
extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
x = q_sample(x_start = x_start, t = t, noise = noise) # 这就是x0在时间步T的输出
第五步: 预测噪声.输入
model_out = self.model(x, t, x_self_cond=None) # 预测噪声
这里面有一个需要注意的点:模型是如何对时间步进行编码并使用的
- 首先会对时间步进行一个编码,将其变为一个向量,以正弦编码为例
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
"""
Args:
x (Tensor), shape like (B,)
"""
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
# 时间步的编码pipeline如下,本质就是将一个常数映射为一个向量
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(fourier_dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim)
)
- 将时间步的embedding嵌入到Unet的block中,使模型能够学习到时间步的信息
class Block(nn.Module):
def __init__(self, dim, dim_out, groups = 8):
super().__init__()
self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift = None):
x = self.proj(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift # 将时间向量一分为2,一份用于提升幅值,一份用于修改相位
x = self.act(x)
return x
class ResnetBlock(nn.Module):
def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
super().__init__()
self.mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, dim_out * 2)
) if exists(time_emb_dim) else None
self.block1 = Block(dim, dim_out, groups = groups)
self.block2 = Block(dim_out, dim_out, groups = groups)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb = None):
scale_shift = None
if exists(self.mlp) and exists(time_emb):
time_emb = self.mlp(time_emb)
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
scale_shift = time_emb.chunk(2, dim = 1)
h = self.block1(x, scale_shift = scale_shift)
h = self.block2(h)
return h + self.res_conv(x)
第六步:计算损失,反向传播.计算预测的噪声与实际的噪声的损失,损失函数可以是L1或mse
@property
def loss_fn(self):
if self.loss_type == 'l1':
return F.l1_loss
elif self.loss_type == 'l2':
return F.mse_loss
else:
raise ValueError(f'invalid loss type {self.loss_type}')
通过不断迭代上述6步即可完成模型的训练
3.2采样过程
第一步:随机从高斯分布采样一张噪声图片,并给定采样时间步
img = torch.randn(shape, device=device)
第二步: 根据预测的当前时间步的噪声,通过公式计算当前时间步的均值和方差
posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) # 式(24)x_0的系数
posterior_mean_coef = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) # 式(24) x_t的系数
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
) # 求出此时的均值
posterior_variance = extract(self.posterior_variance, t, x_t.shape) # 求出此时的方差
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) # 对方差取对数,可能为了数值稳定性
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
preds = self.model_predictions(x, t, x_self_cond) # 预测噪声
x_start = preds.pred_x_start # 模型预测的是在x_t时间步噪声,x_start是根据公式(12)求
if clip_denoised:
x_start.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
return model_mean, posterior_variance, posterior_log_variance, x_start
第三步: 根据公式(32)计算得到前一个时刻图片
@torch.no_grad()
def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True):
b, *_, device = *x.shape, x.device
batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = clip_denoised) # 计算当前分布的均值和方差
noise = torch.randn_like(x) if t > 0 else 0. # 从高斯分布采样噪声
pred_img = model_mean + (0.5 * model_log_variance).exp() * noise # 根据
return pred_img, x_start
通过迭代以上三步,直至
思考和讨论
DDPM区别与传统的VAE与GAN采用了一种新的范式实现了更高质量的图像生成.但实践发现,需要较大的采样步数才能得到较好的生成结果.由于其采样过程是一个马尔可夫的推理过程,导致会有较大的耗时.后续工作如DDIM针对该特性做了优化,数十倍降低采样所用时间。