• diffusion model(一):DDPM技术小结 (denoising diffusion probabilistic)


    发布日期:2023/05/18
    主页地址:http://myhz0606.com/article/ddpm

    1 从直觉上理解DDPM

    在详细推到公式之前,我们先从直觉上理解一下什么是扩散

    对于常规的生成模型,如GAN,VAE,它直接从噪声数据生成图像,我们不妨记噪声数据为z,其生成的图片为x

    对于常规的生成模型

    学习一个解码函数(即我们需要学习的模型)p,实现 p(z)=x

    (1)zpx

    常规方法只需要一次预测即能实现噪声到目标的映射,虽然速度快,但是效果不稳定。

    常规生成模型的训练过程(以VAE为例)

    (2)xqzpx^

    对于diffusion model

    它将噪声到目标的过程进行了多步拆解。不妨假设一共有T+1个时间步,第T个时间步 xT是噪声数据,第0个时间步的输出是目标图片x0。其过程可以表述为:

    (3)z=xTpxT1ppx1px0

    对于DDPM它采用的是一种自回归式的重建方法,每次的输入是当前的时刻及当前时刻的噪声图片。也就是说它把噪声到目标图片的生成分成了T步,这样每一次的预测相当于是对残差的预测。优势是重建效果稳定,但速度较慢。

    训练整体pipeline包含两个过程

    2 diffusion pipeline

    2.1前置知识:

    高斯分布的一些性质

    (1)如果XN(μ,σ2),且ab是实数,那么aX+bN(aμ+b,(aσ)2)

    (2)如果XN(μ(x),σ2(x)),YN(μ(y),σ2(y)),X,Y是统计独立的正态随机变量,则它们的和也满足高斯分布(高斯分布可加性).

    (4)X+YN(μ(x)+μ(y),σ2(x)+σ2(y))XYN(μ(x)μ(y),σ2(x)+σ2(y))

    均值为μ方差为σ的高斯分布的概率密度函数为

    f(x)=12πσexp((xμ)22σ2)(5)=12πσexp[12(1σ2x22μσ2x+μ2σ2)]

    2.2 加噪过程

    1 前向过程:将图片数据映射为噪声

    每一个时刻都要添加高斯噪声,后一个时刻都是由前一个时刻加噪声得到。(其实每一个时刻加的噪声就是训练所用的标签)。即

    (6)x0qx1qx2qqxT1qxT=z

    下面我们详细来看

    βt=1αtβtt的增加而增大(论文中[2]从0.0001 -> 0.02) (这是因为一开始加一点噪声就很明显,后面需要增大噪声的量才明显).DDPM将加噪声过程建模为一个马尔可夫过程q(x1:T|x0):=t=1Tq(xt|xt1),其中q(xt|xt1):=N(xt;αtxt1,(1αt)I)

    xt=αtxt1+(1αt)zt(7)=αtxt1+βtzt

    xt为在t时刻的图片,当t=0时为原图;zt为在t时刻所加的噪声,服从标准正态分布ztN(0,I);αt是常数,是自己定义的变量;从上式可见,随着T增大,xt越来越接近纯高斯分布.

    同理:

    (8)xt1=αt1xt2+1αt1zt1

    将式(8)代入式(7)可得:

    xt=αt(αt1xt2+1αt1zt1)+1αtzt(9)=αtαt1xt2+(αt(1αt1)zt1+1αtzt)

    由于zt1服从均值为0,方差为1的高斯分布(即标准正态分布),根据定义αt(1αt1)zt1服从的是均值为0,方差为αt(1αt1)的高斯分布.即αt(1αt1)zt1N(0,αt(1αt1)I).同理可得1αtztN(0,(1αt)I).则(高斯分布可加性,可以通过定义推得,不赘述)

    (10)(αt(1αt1),zt1+1αtzt)N(0,αt(1αt1)+1αt)=N(0,1αtαt1)

    我们不妨记zt2N(0,I),1αtαt1zt2N(0,(1αtαt1)I)则式(10)最终可改写为

    (11)xt=αtαt1xt2+1αtαt1zt2

    通过递推,容易得到

    xt=αtαt1α1x0+1αtαt1α1z0=i=1tαix0+1i=1tαiz0(12)=αt=i=1tαiαtx0+1αtz0

    其中z0N(0,I),x0为原图.从式(13)可见,我们可以从x0得到任意时刻的xt的分布,而无需按照时间顺序递推!这极大提升了计算效率.

    q(xt|x0)=N(xt;μ(xt,t),σ2(xt,t)I)(13)=N(xt;αtx0,(1αt)I)

    ⚠️加噪过程是确定的,没有模型的介入. 其目的是制作训练时标签

    2.3 去噪过程

    给定xT如何求出x0呢?直接求解是很难的,作者给出的方案是:我们可以一步一步求解.即学习一个解码函数p,这个p能够知道xtxt1的映射规则.如何定义这个p是问题的关键.有了p,只需从xtxt1逐步迭代,即可得出x0.

    (14)z=xTpxT1ppx1px0

    去噪过程是加噪过程的逆向.如果说加噪过程是求给定初始分布x0求任意时刻的分布xt,即q(xt|x0)那么去噪过程所求的分布就是给定任意时刻的分布xt求其初始时刻的分布x0,p(x0|xt) ,通过马尔可夫假设,可以对上述问题进行化简

    p(x0|xt)=p(x0|x1)p(x1|x2)p(xt1|xt)(15)=i=0t1p(xi|xi+1)

    如何求p(xt1|xt)呢?前面的加噪过程我们大力气推到出了q(xt|xt1),我们可以通过贝叶斯公式把它利用起来

    (16)p(xt1|xt)=p(xt|xt1)p(xt1)p(xt)

    ⚠️这里的(去噪)p和上面的(加噪)q只是对分布的一种符号记法。

    有了式(17)还是一头雾水,p(xt)p(xt1)都不知道啊!该怎么办呢?这就要借助模型的威力了.下面来看如何构建我们的模型.

    延续加噪过程的推导p(xt|x0)p(xt1|x0)我们是可以知道的.因此若我们知道初始分布x0,则

    p(xt1|xt,x0)=p(xt|xt1,x0)p(xt1|x0)p(xt|x0)(17)=N(xt;αtxt1,(1αt)I)N(xt1;αt1x0,(1αt1)I)N(xt;αtx0,(1αt)I)(18)(5)exp((xtαtxt1)22(1αt))exp((xt1αt1x0)22(1αt1))exp((xtαtx0)22(1αt))(19)=exp[12((xtαtxt1)21αt+(xt1αt1x0)21αt1(xtαtx0)21αt)](20)=exp[12((αt1αt+11αt1)xt12(2αt1αtxt+2αt11αt1x0)xt1+C(xt,x0))](21)

    结合高斯分布的定义(6)来看式(22),不难发现p(xt1|xt,x0)也是服从高斯分布的.并且结合式(6)我们可以求出其方差和均值

    ⚠️式17做了一个近似p(xt|xt1,x0)=p(xt|xt1),能做这个近似原因是一阶马尔科夫假设,当前时间点只依赖前一个时刻的时间点.

    1σ2=αt1αt+11αt1(22)2μσ2=2αt1αtxt+2αt11αt1x0(23)

    可以求得:

    σ2=1αt11αt(1αt)(24)μ=αt(1αt1)1αtxt+αt1(1αt)1αtx0

    通过上式,我们可得

    (25)p(xt1|xt,x0)=N(xt1;αt(1αt1)1αtxt+αt1(1αt)1αtx0,(1αt11αt(1αt))I)

    该式是真实的条件分布.我们目标是让模型学到的条件分布pθ(xt1|xt)尽可能的接近真实的条件分布p(xt1|xt,x0).从上式可以看到方差是个固定量,那么我们要做的就是让p(xt1|xt,x0)pθ(xt1|xt)的均值尽可能的对齐,即

    (这个结论也可以通过最小化上述两个分布的KL散度推得)

    (26)argminθu(x0,xt),uθ(xt,t)

    下面的问题变为:如何构造uθ(xt,t)来使我们的优化尽可能的简单

    我们注意到μ(x0,xt)μθ(xt,t)都是关于xt的函数,不妨让他们的xt保持一致,则可将μθ(xt,t)写成

    (27)μθ(xt,t)=αt(1αt1)1αtxt+αt1(1αt)1αtfθ(xt,t)

    fθ(xt,t)是我们需要训练的模型.这样对齐均值的问题就转化成了: 给定xt,t来预测原始图片输入x0.根据上文的加噪过程,我们可以很容易制造训练所需的数据对! (Dalle2的训练采用的是这个方式).事情到这里就结束了吗?

    DDPM作者表示直接从xtx0的预测数据跨度太大了,且效果一般.我们可以将式(12)做一下变形

    xt=αtx0+1αtz0(28)x0=1αt(xt1αtz0)

    代入到式(24)中

    μ=αt(1αt1)1αtxt+αt1(1αt)1αt1at(xt1atz0)=αt(1αt1)1αtxt+(1αt)1αt1αt(xt1αtz0)=xtαt(1αt1)+(1αt)αt(1αt)xt1αt(1αt)αt(1αt)z0=1αtαt(1αt)xt1αtαt1αtz0(29)=1αtxt1αtαt1αtz0

    经过这次化简,我们将μ(x0,xt)μ(xt,z0),其中z0N(0,I),可以将式(29)转变为

    (30)μθ(xt,t)=1αtxt1αtαt1αtfθ(xt,t)

    此时对齐均值的问题就转化成:给定xt,t预测xt加入的噪声z0, 也就是说我们的模型预测的是噪声fθ(xt,t)=ϵθ(xt,t)z0

    2.3.1 训练与采样过程

    训练的目标就是这所有时刻两个噪声的差异的期望越小越好(用MSE或L1-loss).

    (31)EtTϵϵθ(xt,t)22

    下图为论文提供的训练和采样过程

    image

    2.3.2 采样过程

    通过以上讨论,我们推导出pθ(xt1|xt)高斯分布的均值和方差.pθ(xt1|xt)=N(xt1;μθ(xt,t),σ2(t)I),根据文献[1]从一个高斯分布中采样一个随机变量可用一个重参数化技巧进行近似

    xt1=μθ(xt,t)+σ(t)ϵ,ϵN(ϵ;0,I)(32)=1αt(xt1αt1αtϵθ(xt,t))+σ(t)ϵ

    式(32)和论文给出的采样递推公式一致.

    至此,已完成DDPM整体的pipeline.

    还没想明白的点,为什么不能根据(7)的变形来进行采样计算呢?

    (33)xt1=1αtxt1αtαtfθ(xt,t)

    3 从代码理解训练&预测过程

    3.1 训练过程

    参考代码仓库: https://github.com/lucidrains/denoising-diffusion-pytorch/tree/main/denoising_diffusion_pytorch

    已知项: 我们假定有一批N张图片{xi|i=1,2,,N}

    第一步: 随机采样K组成batch,如x_start={xk|k=1,2,,K},Shape(x_start)=(K,C,H,W)

    第二步: 随机采样一些时间步

    t = torch.randint(0, self.num_timesteps, (b,), device=device).long()  # 随机采样时间步
    

    第三步: 随机采样噪声

    noise = default(noise, lambda: torch.randn_like(x_start))  # 基于高斯分布采样噪声
    

    第四步: 计算x_start在所采样的时间步的输出xT(即加噪声).(根据公式12)

    def linear_beta_schedule(timesteps):
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
    
    betas = linear_beta_schedule(timesteps)
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    
    def extract(a, t, x_shape):
        b, *_ = t.shape
        out = a.gather(-1, t)
        return out.reshape(b, *((1,) * (len(x_shape) - 1)))
    
    def q_sample(x_start, t, noise=None):
      """
      \begin{eqnarray}
        x_t &=& \sqrt{\alpha_t}x_{t-1} + \sqrt{(1 - \alpha_t)}z_t \nonumber \\
        &=&  \sqrt{\alpha_t}x_{t-1} + \sqrt{\beta_t}z_t
      \end{eqnarray}
      """
        return (
            extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )
    
    x = q_sample(x_start = x_start, t = t, noise = noise)  # 这就是x0在时间步T的输出
    

    第五步: 预测噪声.输入xT,t到噪声预测模型,来预测此时的噪声z^t=ϵθ(xT,t).论文用到的模型结构是Unet,与传统Unet的输入有所不同的是增加了一个时间步的输入.

    model_out = self.model(x, t, x_self_cond=None)  # 预测噪声
    

    这里面有一个需要注意的点:模型是如何对时间步进行编码并使用的

    • 首先会对时间步进行一个编码,将其变为一个向量,以正弦编码为例
    class SinusoidalPosEmb(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.dim = dim
    
        def forward(self, x):
            """
            Args:
              x (Tensor), shape like (B,)
            """
            device = x.device
            half_dim = self.dim // 2
            emb = math.log(10000) / (half_dim - 1)
            emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
            emb = x[:, None] * emb[None, :]
            emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
            return emb
    
    # 时间步的编码pipeline如下,本质就是将一个常数映射为一个向量
    self.time_mlp = nn.Sequential(
        SinusoidalPosEmb(dim),
        nn.Linear(fourier_dim, time_dim),
        nn.GELU(),
        nn.Linear(time_dim, time_dim)
    )
    
    • 将时间步的embedding嵌入到Unet的block中,使模型能够学习到时间步的信息
    class Block(nn.Module):
        def __init__(self, dim, dim_out, groups = 8):
            super().__init__()
            self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
            self.norm = nn.GroupNorm(groups, dim_out)
            self.act = nn.SiLU()
    
        def forward(self, x, scale_shift = None):
            x = self.proj(x)
            x = self.norm(x)
    
            if exists(scale_shift):
                scale, shift = scale_shift
                x = x * (scale + 1) + shift  # 将时间向量一分为2,一份用于提升幅值,一份用于修改相位
    
            x = self.act(x)
            return x
    
    class ResnetBlock(nn.Module):
        def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
            super().__init__()
            self.mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, dim_out * 2)
            ) if exists(time_emb_dim) else None
    
            self.block1 = Block(dim, dim_out, groups = groups)
            self.block2 = Block(dim_out, dim_out, groups = groups)
            self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
    
        def forward(self, x, time_emb = None):
    
            scale_shift = None
            if exists(self.mlp) and exists(time_emb):
                time_emb = self.mlp(time_emb)
                time_emb = rearrange(time_emb, 'b c -> b c 1 1')
                scale_shift = time_emb.chunk(2, dim = 1)
    
            h = self.block1(x, scale_shift = scale_shift)
    
            h = self.block2(h)
    
            return h + self.res_conv(x)
    

    第六步:计算损失,反向传播.计算预测的噪声与实际的噪声的损失,损失函数可以是L1或mse

    @property
        def loss_fn(self):
            if self.loss_type == 'l1':
                return F.l1_loss
            elif self.loss_type == 'l2':
                return F.mse_loss
            else:
                raise ValueError(f'invalid loss type {self.loss_type}')
    
    

    通过不断迭代上述6步即可完成模型的训练

    3.2采样过程

    第一步:随机从高斯分布采样一张噪声图片,并给定采样时间步

    img = torch.randn(shape, device=device)
    

    第二步: 根据预测的当前时间步的噪声,通过公式计算当前时间步的均值和方差

    
      posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) # 式(24)x_0的系数
      posterior_mean_coef = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)  # 式(24) x_t的系数
    
      def extract(a, t, x_shape):
        b, *_ = t.shape
        out = a.gather(-1, t)
        return out.reshape(b, *((1,) * (len(x_shape) - 1)))
    
      def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )  # 求出此时的均值
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)  # 求出此时的方差
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) # 对方差取对数,可能为了数值稳定性
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
    
      def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
          preds = self.model_predictions(x, t, x_self_cond)  # 预测噪声
          x_start = preds.pred_x_start  # 模型预测的是在x_t时间步噪声,x_start是根据公式(12)求
    
          if clip_denoised:
              x_start.clamp_(-1., 1.)
    
          model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
          return model_mean, posterior_variance, posterior_log_variance, x_start
    
    

    第三步: 根据公式(32)计算得到前一个时刻图片xt1

      @torch.no_grad()
      def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True):
          b, *_, device = *x.shape, x.device
          batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
          model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = clip_denoised)  # 计算当前分布的均值和方差
          noise = torch.randn_like(x) if t > 0 else 0. # 从高斯分布采样噪声
          pred_img = model_mean + (0.5 * model_log_variance).exp() * noise  # 根据
          return pred_img, x_start
    

    通过迭代以上三步,直至T=0完成采样.

    思考和讨论

    DDPM区别与传统的VAE与GAN采用了一种新的范式实现了更高质量的图像生成.但实践发现,需要较大的采样步数才能得到较好的生成结果.由于其采样过程是一个马尔可夫的推理过程,导致会有较大的耗时.后续工作如DDIM针对该特性做了优化,数十倍降低采样所用时间。

    参考文献

    [1]  Understanding Diffusion Models: A Unified Perspective

    [2]  Denoising Diffusion Probabilistic Models

  • 相关阅读:
    图论例题解析
    这可能是你进腾讯最后的机会了..
    Elasticsearch:使用 Amazon Bedrock 的 semantic_text
    CSDN更换背景 调整博客风格与代码块样式方法
    springboot幼儿园幼儿基本信息管理系统设计与实现毕业设计源码201126
    五月集训(第二十三日)字典树
    从零开始,开发一个 Web Office 套件(11):支持中文输入法(or 其它使用输入法的语言)
    浅析量化交易是什么类型的交易?
    【Markdown语法高级】让你的博客更精彩(四:设置字体样式以及颜色对照表)
    apifox的使用以及和idea集成
  • 原文地址:https://www.cnblogs.com/myhz/p/18210120