• Deep Learning: (Generative Models) DDPM


    Prerequisite math

    1. Prior and posterior probability

    Prior probability: a probability obtained from past experience and analysis. It usually appears as the "cause" in a "from cause to effect" problem, e.g. $q(x_t|x_{t-1})$.

    Posterior probability: the probability revised after information about the "effect" has been obtained. It is the "cause" in a "from effect back to cause" problem, e.g. $p(x_{t-1}|x_t)$.

    2. Conditional probability: let $A$ and $B$ be any two events with $P(A)>0$. The probability that $B$ occurs given that $A$ has occurred is called a conditional probability and is written $P(B|A)$:
    $$P(B|A)=\frac{P(A,B)}{P(A)}$$

    3. Multiplication rule:
    $$P(A,B)=P(B|A)P(A)$$

    4. General form of the multiplication rule:
    $$P(A,B,C)=P(C|B,A)P(B,A)=P(C|B,A)P(B|A)P(A)$$

    5. Bayes' formula:
    $$P(A|B)=\frac{P(B|A)P(A)}{P(B)}$$
    6. Bayes' formula with an extra condition:
    $$P(A|B,C)=\frac{P(A,B,C)}{P(B,C)}=\frac{P(B|A,C)P(A,C)}{P(B,C)}=\frac{P(B|A,C)P(A|C)P(C)}{P(B|C)P(C)}=\frac{P(B|A,C)P(A|C)}{P(B|C)}$$

    7. Additivity of normal distributions: let $N_1$ and $N_2$ be independent normal random variables with means $\mu_1,\mu_2$ and variances $\sigma_1^2,\sigma_2^2$. The linear combination $N=aN_1+bN_2$ is again normal, with mean and variance
    $$E(N)=E(aN_1+bN_2)=a\mu_1+b\mu_2,\qquad Var(N)=Var(aN_1+bN_2)=a^2\sigma_1^2+b^2\sigma_2^2$$
    For a difference, the means subtract but the variances still add:
    $$E(N)=E(aN_1-bN_2)=a\mu_1-b\mu_2,\qquad Var(N)=Var(aN_1-bN_2)=a^2\sigma_1^2+b^2\sigma_2^2$$

    8. Reparameterization: sampling from $N(\mu,\sigma^2)$ is equivalent to sampling $\epsilon$ from $N(0,1)$ and then computing $\epsilon\cdot\sigma+\mu$.

    9. Probability density function of the Gaussian distribution
    $$f(x)=\frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{(x-\mu)^2}{2\sigma^2}}$$
    10. KL divergence between two Gaussians $p=N(\mu_1,\sigma_1^2)$ and $q=N(\mu_2,\sigma_2^2)$
    $$KL(p\|q)=\log\frac{\sigma_2}{\sigma_1}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac{1}{2}$$
    11. Completing the square for a quadratic
    $$ax^2+bx=a\left(x+\frac{b}{2a}\right)^2+c,\qquad c=-\frac{b^2}{4a}$$
    12. Expectation of a function of a random variable
    Let $X$ be a random variable and $Y=g(X)$. Then:
    $$E(Y)=E[g(X)]=\begin{cases}\sum_{k=1}^{\infty}g(x_k)p_k & \text{(discrete)}\\ \int_{-\infty}^{\infty}g(x)p(x)\,dx & \text{(continuous)}\end{cases}$$

    13. KL divergence
    $$KL(p(x)\|q(x))=E_{x\sim p(x)}\left[\log\frac{p(x)}{q(x)}\right]=\int p(x)\log\frac{p(x)}{q(x)}\,dx$$
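Several of the Gaussian facts above (items 7, 8, 10 and 13) can be checked numerically. The sketch below is a minimal pure-Python illustration assuming 1-D Gaussians; the helper names (`gaussian_log_pdf`, `sample_normal`, `gaussian_kl`, `monte_carlo_kl`, `combo_moments`) are hypothetical and not part of the DDPM code later in this post. It samples with the reparameterization of item 8, compares the closed-form KL of item 10 with a Monte Carlo estimate of the expectation in item 13, and checks the mean and variance rule of item 7 for a difference $2N_1-N_2$.

```python
import math
import random


def gaussian_log_pdf(x, mu, sigma):
    """Log of the Gaussian density f(x) from item 9."""
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - (x - mu) ** 2 / (2 * sigma ** 2)


def sample_normal(mu, sigma, rng):
    """Item 8: draw eps ~ N(0, 1) and return eps * sigma + mu."""
    return rng.gauss(0.0, 1.0) * sigma + mu


def gaussian_kl(mu1, sigma1, mu2, sigma2):
    """Item 10: closed-form KL(N(mu1, sigma1^2) || N(mu2, sigma2^2))."""
    return (math.log(sigma2 / sigma1)
            + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2)
            - 0.5)


def monte_carlo_kl(mu1, sigma1, mu2, sigma2, n=100_000, seed=0):
    """Item 13: estimate E_{x~p}[log p(x) - log q(x)] by sampling x from p."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        x = sample_normal(mu1, sigma1, rng)
        total += gaussian_log_pdf(x, mu1, sigma1) - gaussian_log_pdf(x, mu2, sigma2)
    return total / n


def combo_moments(a, mu1, s1, b, mu2, s2, n=100_000, seed=1):
    """Item 7: empirical mean and variance of a*N1 + b*N2 for independent N1, N2."""
    rng = random.Random(seed)
    xs = [a * sample_normal(mu1, s1, rng) + b * sample_normal(mu2, s2, rng) for _ in range(n)]
    mean = sum(xs) / n
    var = sum((x - mean) ** 2 for x in xs) / n
    return mean, var


# Usage: closed form vs. sampling, then 2*N1 - N2 with N1 ~ N(1, 1), N2 ~ N(3, 0.5^2).
print(gaussian_kl(0.0, 1.0, 1.0, 2.0), monte_carlo_kl(0.0, 1.0, 1.0, 2.0))
print(combo_moments(2.0, 1.0, 1.0, -1.0, 3.0, 0.5))  # theory: mean -1.0, variance 4 + 0.25 = 4.25
```

The subtraction case is covered simply by passing a negative `b`: the mean picks up the sign, while the variance uses $b^2$.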

    Introduction to DDPM

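    DDPM (Denoising Diffusion Probabilistic Models), proposed at Berkeley in 2020 and usually just called the diffusion model, is the core algorithm behind AIGC. It surpasses GANs in both the realism and the diversity of generated images, and its training is stable. Its drawbacks are a high computational cost and the difficulty of real-time inference, although there are techniques that reduce the computation along both the time and the space dimensions.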

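    A diffusion model consists of two processes: the forward diffusion process (forward noising process) and the reverse denoising process.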

    [Figure]

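    Both the forward and the reverse process are Markov chains, and the full process takes about 1000 steps. The reverse process is the one that generates data; its derivation can be summarized as: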

    [Figure]

    Forward diffusion process

    The forward diffusion process gradually adds Gaussian noise to the original data until it becomes a standard Gaussian distribution.

    [Figure]

    Sample $x_0\sim q(x_0)$ from the original dataset and add random noise according to a predefined noise schedule, obtaining a sequence of noisy images $x_1,x_2,\dots,x_T$. In probabilistic terms:
    $$q(x_{1:T}|x_0)=\prod_{t=1}^{T}q(x_t|x_{t-1}),\qquad q(x_t|x_{t-1})=\mathcal{N}\left(x_t;\sqrt{\alpha_t}\,x_{t-1},\beta_t I\right)$$
    Applying reparameterization (prerequisite 8) gives
    $$x_t=\sqrt{\alpha_t}\,x_{t-1}+\sqrt{\beta_t}\,\epsilon_t,\qquad \epsilon_t\sim\mathcal{N}(0,I),\qquad \alpha_t=1-\beta_t$$
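A minimal sketch of this single noising step for a scalar $x_{t-1}$, assuming the linear schedule from $\beta_1=10^{-4}$ to $\beta_T=0.02$ that `linear_beta_schedule` in the code below builds with `torch.linspace`. The plain-Python `linear_beta_schedule`, `forward_step` and `run_chain` here are hypothetical stand-ins for illustration. Because $\alpha_t+\beta_t=1$, one step keeps a unit-variance input at unit variance, and repeating the step $T=1000$ times should drive any starting value toward $\mathcal{N}(0,1)$.

```python
import math
import random


def linear_beta_schedule(T, start=1e-4, end=0.02):
    """beta_1..beta_T spaced linearly, mirroring torch.linspace(start, end, T)."""
    if T == 1:
        return [start]
    return [start + (end - start) * i / (T - 1) for i in range(T)]


def forward_step(x_prev, beta_t, rng):
    """Draw x_t ~ q(x_t | x_{t-1}) via reparameterization:
    x_t = sqrt(alpha_t) * x_{t-1} + sqrt(beta_t) * eps_t, alpha_t = 1 - beta_t, eps_t ~ N(0, 1)."""
    eps_t = rng.gauss(0.0, 1.0)
    return math.sqrt(1.0 - beta_t) * x_prev + math.sqrt(beta_t) * eps_t


def run_chain(x0, betas, rng):
    """Apply forward_step for t = 1..T and return x_T (scalar x_0)."""
    x = x0
    for beta_t in betas:
        x = forward_step(x, beta_t, rng)
    return x


# Usage: push the constant "image" x_0 = 5.0 through T = 1000 steps, 1000 times.
rng = random.Random(0)
betas = linear_beta_schedule(1000)
samples = [run_chain(5.0, betas, rng) for _ in range(1000)]
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"mean={mean:.3f}, var={var:.3f}")  # should be near 0 and 1, i.e. close to N(0, 1)
```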

    Iterating this formula:
    $$\begin{aligned}x_t&=\sqrt{\alpha_t}\,x_{t-1}+\sqrt{\beta_t}\,\epsilon_t\\&=\sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}}\,x_{t-2}+\sqrt{\beta_{t-1}}\,\epsilon_{t-1}\right)+\sqrt{\beta_t}\,\epsilon_t\\&=\sqrt{\alpha_t\cdots\alpha_1}\,x_0+\sqrt{(\alpha_t\cdots\alpha_2)\beta_1}\,\epsilon_1+\sqrt{(\alpha_t\cdots\alpha_3)\beta_2}\,\epsilon_2+\cdots+\sqrt{\alpha_t\beta_{t-1}}\,\epsilon_{t-1}+\sqrt{\beta_t}\,\epsilon_t\end{aligned}$$

    Let $\bar{\alpha}_t=\alpha_1\alpha_2\cdots\alpha_t$.

    Since the $\epsilon_s$ are independent standard Gaussians, additivity of normal distributions (prerequisite 7) merges all noise terms into a single Gaussian with mean 0 and variance $1-\bar{\alpha}_t$, so
    $$x_t=\sqrt{\bar{\alpha}_t}\,x_0+\sqrt{1-\bar{\alpha}_t}\,\epsilon,\qquad \epsilon\sim\mathcal{N}(0,I)$$
    $$q(x_t|x_0)=\mathcal{N}\left(x_t;\sqrt{\bar{\alpha}_t}\,x_0,(1-\bar{\alpha}_t)I\right)$$
    This formula says that the noisy image $x_t$ at any step $t$ can be obtained by adding noise to $x_0$ directly, in one step. It is used again below.
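The variance of the merged noise term follows from a telescoping sum. Squaring the coefficients of $\epsilon_1,\dots,\epsilon_t$ in the iterated expansion, writing $\beta_s=1-\alpha_s$, and treating the empty product (at $s=t$) as 1:

$$\sum_{s=1}^{t}\Big(\prod_{k=s+1}^{t}\alpha_k\Big)\beta_s=\sum_{s=1}^{t}\Big(\prod_{k=s+1}^{t}\alpha_k-\prod_{k=s}^{t}\alpha_k\Big)=1-\prod_{k=1}^{t}\alpha_k=1-\bar{\alpha}_t$$

So the squared signal coefficient $\bar{\alpha}_t$ and the noise variance $1-\bar{\alpha}_t$ always add up to 1.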

    Note: for this reason, the code computes the forward process in a single step rather than iterating it.
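The one-step shortcut can be sketched in plain Python for a scalar $x_0$, assuming the same linear schedule with $T=300$ that the code below uses. `linear_beta_schedule`, `alpha_bars`, `q_sample`, `q_sample_iterative` and `mean_var` are hypothetical helpers for illustration: `q_sample` plays the role of `forward_diffusion_sample` (one shot, also returning the noise), while `q_sample_iterative` applies $x_s=\sqrt{\alpha_s}\,x_{s-1}+\sqrt{\beta_s}\,\epsilon_s$ step by step. Both routes should produce samples with mean $\sqrt{\bar{\alpha}_t}\,x_0$ and variance $1-\bar{\alpha}_t$.

```python
import math
import random


def linear_beta_schedule(T, start=1e-4, end=0.02):
    """beta_1..beta_T, linearly spaced like torch.linspace(start, end, T) (assumes T >= 2)."""
    return [start + (end - start) * i / (T - 1) for i in range(T)]


def alpha_bars(betas):
    """bar_alpha_t = alpha_1 * ... * alpha_t, stored at index t - 1."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def q_sample(x0, t, abar, rng):
    """One shot (t is 1-based): return x_t = sqrt(bar_alpha_t) x_0 + sqrt(1 - bar_alpha_t) eps, and eps."""
    eps = rng.gauss(0.0, 1.0)
    a = abar[t - 1]
    return math.sqrt(a) * x0 + math.sqrt(1.0 - a) * eps, eps


def q_sample_iterative(x0, t, betas, rng):
    """Step by step: x_s = sqrt(alpha_s) x_{s-1} + sqrt(beta_s) eps_s for s = 1..t."""
    x = x0
    for beta in betas[:t]:
        x = math.sqrt(1.0 - beta) * x + math.sqrt(beta) * rng.gauss(0.0, 1.0)
    return x


def mean_var(xs):
    """Sample mean and (biased) sample variance."""
    m = sum(xs) / len(xs)
    return m, sum((x - m) ** 2 for x in xs) / len(xs)


# Usage: both routes at t = 150 for x_0 = 2.0, against the closed-form moments.
T, t, x0 = 300, 150, 2.0
betas = linear_beta_schedule(T)
abar = alpha_bars(betas)
rng = random.Random(0)
one_shot = mean_var([q_sample(x0, t, abar, rng)[0] for _ in range(20_000)])
iterated = mean_var([q_sample_iterative(x0, t, betas, rng) for _ in range(3_000)])
theory = (math.sqrt(abar[t - 1]) * x0, 1.0 - abar[t - 1])
print("one-shot:", one_shot, "iterated:", iterated, "theory:", theory)
```

The one-shot route costs one Gaussian draw regardless of $t$, which is why training can pick a random $t$ per example cheaply.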

    Reverse denoising process: what the neural network fits

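    The reverse denoising process is the data-generation process: it first draws a noise sample from the standard Gaussian distribution, then denoises it iteratively, step by step, and finally obtains a sample from the data distribution.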

    [Figure]

    If we knew the true conditional distribution $q(x_{t-1}|x_t)$ of every reverse step, we could start from random noise and sample step by step to generate a real sample. By Bayes' formula, however, the true conditional
    $$q(x_{t-1}|x_t)=\frac{q(x_t|x_{t-1})q(x_{t-1})}{q(x_t)}$$
    cannot be computed directly, because $q(x_{t-1})$ and $q(x_t)$ are unknown. We therefore cannot go from $x_t$ to $x_{t-1}$ exactly, and the conditional must be approximated by a neural network $p_\theta(x_{t-1}|x_t)$. For simplicity, **the reverse process is also defined as a Markov chain whose transitions are Gaussian**, modeled as follows:

    $$p_\theta(x_{0:T})=p(x_T)\prod_{t=1}^{T}p_\theta(x_{t-1}|x_t),\qquad p_\theta(x_{t-1}|x_t)=\mathcal{N}\left(x_{t-1};\mu_\theta(x_t,t),\Sigma_\theta(x_t,t)\right)$$

    --------------------The following explanation departs somewhat from the above; it lays the groundwork for the loss function------------------------------

    Although the true conditional $q(x_{t-1}|x_t)$ cannot be computed directly, the posterior that is additionally conditioned on $x_0$, $q(x_{t-1}|x_t,x_0)$, can be computed with Bayes' formula (prerequisite 6). Combined with the Markov property of the forward process, $q(x_t|x_{t-1},x_0)=q(x_t|x_{t-1})$, this gives
    $$q(x_{t-1}|x_t,x_0)=\frac{q(x_t|x_{t-1},x_0)q(x_{t-1}|x_0)}{q(x_t|x_0)}=\frac{q(x_t|x_{t-1})q(x_{t-1}|x_0)}{q(x_t|x_0)}$$
    All three factors on the right are known Gaussians:
    $$q(x_{t-1}|x_0)=\mathcal{N}\left(x_{t-1};\sqrt{\bar{\alpha}_{t-1}}\,x_0,(1-\bar{\alpha}_{t-1})I\right),\qquad x_{t-1}=\sqrt{\bar{\alpha}_{t-1}}\,x_0+\sqrt{1-\bar{\alpha}_{t-1}}\,\epsilon$$
    $$q(x_t|x_0)=\mathcal{N}\left(x_t;\sqrt{\bar{\alpha}_t}\,x_0,(1-\bar{\alpha}_t)I\right),\qquad x_t=\sqrt{\bar{\alpha}_t}\,x_0+\sqrt{1-\bar{\alpha}_t}\,\epsilon$$
    $$q(x_t|x_{t-1})=\mathcal{N}\left(x_t;\sqrt{\alpha_t}\,x_{t-1},\beta_t I\right),\qquad x_t=\sqrt{\alpha_t}\,x_{t-1}+\sqrt{\beta_t}\,\epsilon$$
    Therefore (per coordinate, since every covariance is a multiple of $I$)
    $$q(x_{t-1}|x_t,x_0)\propto\exp\left(-\frac{1}{2}\left(\frac{(x_t-\sqrt{\alpha_t}\,x_{t-1})^2}{\beta_t}+\frac{(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}}\,x_0)^2}{1-\bar{\alpha}_{t-1}}-\frac{(x_t-\sqrt{\bar{\alpha}_t}\,x_0)^2}{1-\bar{\alpha}_t}\right)\right)$$
    $$=\exp\left(-\frac{1}{2}\left(\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right)x_{t-1}^2-\left(\frac{2\sqrt{\alpha_t}}{\beta_t}x_t+\frac{2\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}x_0\right)x_{t-1}+C(x_t,x_0)\right)\right)$$
    where $C(x_t,x_0)$ collects the terms that do not depend on $x_{t-1}$.

    Completing the square (prerequisite 11) then gives the variance and mean of this Gaussian, using $\alpha_t(1-\bar{\alpha}_{t-1})+\beta_t=1-\bar{\alpha}_t$ to simplify:
    $$\tilde{\beta}_t=1\Big/\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right)=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t$$
    $$\tilde{\mu}_t=\left(\frac{\sqrt{\alpha_t}}{\beta_t}x_t+\frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}x_0\right)\Big/\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right)=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t+\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}x_0$$

    Moreover, inverting $x_t=\sqrt{\bar{\alpha}_t}\,x_0+\sqrt{1-\bar{\alpha}_t}\,\epsilon$ gives
    $$x_0=\frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t-\sqrt{1-\bar{\alpha}_t}\,\epsilon\right)$$
    and substituting this into $\tilde{\mu}_t$ yields
    $$\tilde{\mu}_t=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon\right)$$
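As a sanity check on the algebra, the three expressions for the posterior (the completed-square form, the closed form in $x_t$ and $x_0$, and the $\epsilon$ form) can be compared numerically. This is a minimal sketch assuming scalar data, the linear schedule with $T=300$ from the code below, and $t\ge 2$ (at $t=1$ the term $1/(1-\bar{\alpha}_0)$ is infinite); `schedule`, `posterior_from_square`, `posterior_closed_form` and `posterior_mean_from_eps` are hypothetical names.

```python
import math


def schedule(T=300, start=1e-4, end=0.02):
    """Return beta_t, alpha_t and bar_alpha_t as lists indexed by t - 1."""
    betas = [start + (end - start) * i / (T - 1) for i in range(T)]
    alphas = [1.0 - b for b in betas]
    abar, prod = [], 1.0
    for a in alphas:
        prod *= a
        abar.append(prod)
    return betas, alphas, abar


BETAS, ALPHAS, ABAR = schedule()


def posterior_from_square(x_t, x0, t):
    """Completed square: variance = 1 / precision, mean = (linear coefficient / 2) / precision."""
    if t < 2:
        raise ValueError("t must be >= 2 for this form")
    b, a, abar_prev = BETAS[t - 1], ALPHAS[t - 1], ABAR[t - 2]
    precision = a / b + 1.0 / (1.0 - abar_prev)
    linear = math.sqrt(a) / b * x_t + math.sqrt(abar_prev) / (1.0 - abar_prev) * x0
    return linear / precision, 1.0 / precision


def posterior_closed_form(x_t, x0, t):
    """tilde_mu_t and tilde_beta_t written with 1 - bar_alpha_t in the denominator."""
    b, a, abar_t, abar_prev = BETAS[t - 1], ALPHAS[t - 1], ABAR[t - 1], ABAR[t - 2]
    mean = (math.sqrt(a) * (1.0 - abar_prev) / (1.0 - abar_t) * x_t
            + math.sqrt(abar_prev) * b / (1.0 - abar_t) * x0)
    return mean, (1.0 - abar_prev) / (1.0 - abar_t) * b


def posterior_mean_from_eps(x_t, eps, t):
    """tilde_mu_t = (x_t - beta_t / sqrt(1 - bar_alpha_t) * eps) / sqrt(alpha_t)."""
    b, a, abar_t = BETAS[t - 1], ALPHAS[t - 1], ABAR[t - 1]
    return (x_t - b / math.sqrt(1.0 - abar_t) * eps) / math.sqrt(a)


# Usage: build x_t from a known x_0 and eps, then compare the three forms.
x0, eps = 0.7, -1.3
for t in (2, 50, 150, 300):
    x_t = math.sqrt(ABAR[t - 1]) * x0 + math.sqrt(1.0 - ABAR[t - 1]) * eps
    print(t, posterior_from_square(x_t, x0, t), posterior_closed_form(x_t, x0, t),
          posterior_mean_from_eps(x_t, eps, t))
```

The $\epsilon$ form is the one the network uses, because at sampling time $x_0$ is unknown but $\epsilon$ can be predicted from $x_t$.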

    ----------------------------------------------------------------------------------------------

    Sampling process (inference once the model has been trained)

    The network $\epsilon_\theta(x_t,t)$ is trained to predict the noise $\epsilon$, so $\mu_\theta$ is $\tilde{\mu}_t$ with $\epsilon$ replaced by this prediction, and the variance is set to $\tilde{\beta}_t$:
    $$\begin{aligned}\mu_\theta(x_t,t)&=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t,t)\right)\\x_{t-1}&\sim p_\theta(x_{t-1}|x_t)\\x_{t-1}&=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t,t)\right)+\sqrt{\tilde{\beta}_t}\,z,\qquad z\sim\mathcal{N}(0,I)\end{aligned}$$
    The symbol $z$ is used here to distinguish this noise from the earlier $\epsilon$. At the final step the code below returns the mean without adding noise.
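The sampling loop can be sketched without any trained network by plugging in an idealized noise predictor. Assume, hypothetically, that the whole dataset is a single scalar point $c=0.42$; then the exact noise is $\epsilon=(x_t-\sqrt{\bar{\alpha}_t}\,c)/\sqrt{1-\bar{\alpha}_t}$, so `oracle_eps` below acts as a perfect $\epsilon_\theta$. With it, each update is an exact draw from $q(x_{t-1}|x_t,x_0=c)$, and the final noise-free step (as in `sample_timestep` in the code below) should land on $c$ up to rounding. The schedule ($T=300$, linear) matches the code; `make_schedule`, `oracle_eps`, `sample` and `C` are illustrative names.

```python
import math
import random


def make_schedule(T=300, start=1e-4, end=0.02):
    """Lists indexed by i = t - 1, mirroring the tensors precomputed in the code below."""
    betas = [start + (end - start) * i / (T - 1) for i in range(T)]
    alphas = [1.0 - b for b in betas]
    abar, prod = [], 1.0
    for a in alphas:
        prod *= a
        abar.append(prod)
    abar_prev = [1.0] + abar[:-1]
    post_var = [b * (1.0 - ap) / (1.0 - ab) for b, ap, ab in zip(betas, abar_prev, abar)]
    return betas, alphas, abar, post_var


BETAS, ALPHAS, ABAR, POST_VAR = make_schedule()
C = 0.42  # hypothetical one-point "dataset"


def oracle_eps(x_t, i):
    """Exact noise for data equal to C: eps = (x_t - sqrt(bar_alpha_t) C) / sqrt(1 - bar_alpha_t)."""
    return (x_t - math.sqrt(ABAR[i]) * C) / math.sqrt(1.0 - ABAR[i])


def sample(eps_model, seed=0):
    """Ancestral sampling from x_T ~ N(0, 1) down to x_0."""
    rng = random.Random(seed)
    x = rng.gauss(0.0, 1.0)                    # x_T ~ N(0, 1)
    for i in reversed(range(len(BETAS))):      # i = T-1, ..., 0 (0-based t, as in the code)
        mean = (x - BETAS[i] / math.sqrt(1.0 - ABAR[i]) * eps_model(x, i)) / math.sqrt(ALPHAS[i])
        if i > 0:
            x = mean + math.sqrt(POST_VAR[i]) * rng.gauss(0.0, 1.0)  # + sqrt(tilde_beta_t) * z
        else:
            x = mean                           # last step: return the mean only
    return x


print([sample(oracle_eps, seed=s) for s in range(3)])
```

A trained $\epsilon_\theta$ only approximates this oracle, which is why real samples vary instead of collapsing to one point.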

    Loss function

    For the derivation of the loss function, see https://blog.csdn.net/weixin_45453121/article/details/131223653
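In practice $\epsilon_\theta$ is trained with a simplified noise-prediction objective: draw $t$ uniformly, noise $x_0$ in one shot, and penalize the gap between the true and the predicted noise. The DDPM paper uses the squared error $\mathbb{E}_{t,x_0,\epsilon}\|\epsilon-\epsilon_\theta(\sqrt{\bar{\alpha}_t}\,x_0+\sqrt{1-\bar{\alpha}_t}\,\epsilon,t)\|^2$, while `get_loss` in the code below uses `F.l1_loss`. The plain-Python sketch below (scalar data, linear schedule with $T=300$; `noise_prediction_loss`, `zero_model`, `oracle_model` and `C` are hypothetical names) estimates either variant by Monte Carlo. A predictor that always outputs 0 should score about $E[\epsilon^2]=1$ under the squared loss, while an idealized predictor for a one-point dataset scores 0.

```python
import math
import random


def alpha_bars(T=300, start=1e-4, end=0.02):
    """bar_alpha_t for a linear beta schedule, indexed by t - 1."""
    abar, prod = [], 1.0
    for i in range(T):
        prod *= 1.0 - (start + (end - start) * i / (T - 1))
        abar.append(prod)
    return abar


ABAR = alpha_bars()
C = 0.42  # hypothetical one-point dataset


def noise_prediction_loss(eps_model, x0_batch, rng, squared=True):
    """Monte Carlo estimate of E[(eps - eps_model(x_t, t))^2], or E|...| if squared=False."""
    total = 0.0
    for x0 in x0_batch:
        i = rng.randrange(len(ABAR))           # t ~ Uniform, 0-based like torch.randint(0, T, ...)
        eps = rng.gauss(0.0, 1.0)
        x_t = math.sqrt(ABAR[i]) * x0 + math.sqrt(1.0 - ABAR[i]) * eps  # one-shot forward noising
        err = eps - eps_model(x_t, i)
        total += err * err if squared else abs(err)
    return total / len(x0_batch)


def zero_model(x_t, i):
    """A predictor that always guesses eps = 0."""
    return 0.0


def oracle_model(x_t, i):
    """Exact noise when every training example equals C."""
    return (x_t - math.sqrt(ABAR[i]) * C) / math.sqrt(1.0 - ABAR[i])


data = [C] * 20_000
rng = random.Random(0)
print(noise_prediction_loss(zero_model, data, rng))                 # close to E[eps^2] = 1
print(noise_prediction_loss(zero_model, data, rng, squared=False))  # close to E|eps| = sqrt(2/pi)
print(noise_prediction_loss(oracle_model, data, rng))               # ~0 up to rounding
```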

    Code

    import torch
    import torchvision
    import matplotlib.pyplot as plt
    import torch.nn.functional as F
    from torchvision import transforms
    from torch.utils.data import DataLoader
    import numpy as np
    from torch.optim import Adam
    from torch import nn
    import math
    from torchvision.utils import save_image
    
    
    def show_images(data, num_samples=20, cols=4):
        """ Plots some samples from the dataset """
        plt.figure(figsize=(15,15))
        for i, img in enumerate(data):
            if i == num_samples:
                break
            plt.subplot(int(num_samples/cols) + 1, cols, i + 1)
            plt.imshow(img[0])
    
    
    def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
        return torch.linspace(start, end, timesteps)
    
    
    def get_index_from_list(vals, t, x_shape):
        """
        Returns a specific index t of a passed list of values vals
        while considering the batch dimension.
        """
        batch_size = t.shape[0]
        out = vals.gather(-1, t.cpu())
        #print("out:",out)
        #print("out.shape:",out.shape)
        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
    
    def forward_diffusion_sample(x_0, t, device="cpu"):
        """
        Takes an image and a timestep as input and
        returns the noisy version of it
        """
        noise = torch.randn_like(x_0)
        sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
        sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
            sqrt_one_minus_alphas_cumprod, t, x_0.shape
        )
        # mean + variance
        return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
        + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)
    
    
    
    
    
    
    def load_transformed_dataset(IMG_SIZE):
        data_transforms = [
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(), # Scales data into [0,1]
            transforms.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1]
        ]
        data_transform = transforms.Compose(data_transforms)
    
        train = torchvision.datasets.MNIST(root="./Data", transform=data_transform, train=True, download=True)
        test = torchvision.datasets.MNIST(root="./Data", transform=data_transform, train=False, download=True)
    
        return torch.utils.data.ConcatDataset([train, test])
    
    def show_tensor_image(image):
        reverse_transforms = transforms.Compose([
            transforms.Lambda(lambda t: (t + 1) / 2),
            transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
            transforms.Lambda(lambda t: t * 255.),
            transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
            transforms.ToPILImage(),
        ])
    
        #Take first image of batch
        if len(image.shape) == 4:
            image = image[0, :, :, :]
        plt.imshow(reverse_transforms(image))
    
    
    
    
    
    class Block(nn.Module):
        def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
            super().__init__()
            self.time_mlp =  nn.Linear(time_emb_dim, out_ch)
            if up:
                self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
                self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
            else:
                self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
                self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
            self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
            self.bnorm1 = nn.BatchNorm2d(out_ch)
            self.bnorm2 = nn.BatchNorm2d(out_ch)
            self.relu  = nn.ReLU()
    
        def forward(self, x, t):
            #print("ttt:",t.shape)
            # First Conv
            h = self.bnorm1(self.relu(self.conv1(x)))
            # Time embedding
            time_emb = self.relu(self.time_mlp(t))
            # Extend last 2 dimensions
            time_emb = time_emb[(..., ) + (None, ) * 2]
            # Add time channel
            h = h + time_emb
            # Second Conv
            h = self.bnorm2(self.relu(self.conv2(h)))
            # Down or Upsample
            return self.transform(h)
    
    
    class SinusoidalPositionEmbeddings(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.dim = dim
    
        def forward(self, time):
            device = time.device
            half_dim = self.dim // 2
            embeddings = math.log(10000) / (half_dim - 1)
            embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
            embeddings = time[:, None] * embeddings[None, :]
            embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
            # TODO: Double check the ordering here
            return embeddings
    
    
    class SimpleUnet(nn.Module):
        """
        A simplified variant of the Unet architecture.
        """
        def __init__(self):
            super().__init__()
            image_channels = 1   # 1 for grayscale (MNIST), 3 for color images
            down_channels = (64, 128, 256, 512, 1024)
            up_channels = (1024, 512, 256, 128, 64)
            out_dim = 1   # 1 for grayscale, 3 for color images
            time_emb_dim = 32
    
            # Time embedding
            self.time_mlp = nn.Sequential(
                    SinusoidalPositionEmbeddings(time_emb_dim),
                    nn.Linear(time_emb_dim, time_emb_dim),
                    nn.ReLU()
                )
    
            # Initial projection
            self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)
    
            # Downsample
            self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1], \
                                        time_emb_dim) \
                        for i in range(len(down_channels)-1)])
            # Upsample
            self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], \
                                            time_emb_dim, up=True) \
                        for i in range(len(up_channels)-1)])
    
            # Edit: Corrected a bug found by Jakub C (see YouTube comment)
            self.output = nn.Conv2d(up_channels[-1], out_dim, 1)
    
        def forward(self, x, timestep):
            # Embedd time
            t = self.time_mlp(timestep)
            # Initial conv
            x = self.conv0(x)
            # Unet
            residual_inputs = []
            for down in self.downs:
                x = down(x, t)
                residual_inputs.append(x)
            for up in self.ups:
                residual_x = residual_inputs.pop()
                # Add residual x as additional channels
                x = torch.cat((x, residual_x), dim=1)
                x = up(x, t)
            return self.output(x)
    
    
    
    
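    def get_loss(model, x_0, t):
        # Noise the clean batch in one shot (closed form) and keep the true noise.
        x_noisy, noise = forward_diffusion_sample(x_0, t, device)
        # The network predicts the noise that was added at step t.
        noise_pred = model(x_noisy, t)
        # The DDPM paper's simplified objective uses the squared error (F.mse_loss);
        # this code uses the L1 distance instead.
        return F.l1_loss(noise, noise_pred)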
    
    
    
    @torch.no_grad()
    def sample_timestep(x, t):
        """
        Calls the model to predict the noise in the image and returns
        the denoised image.
        Applies noise to this image, if we are not in the last step yet.
        """
        betas_t = get_index_from_list(betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
            sqrt_one_minus_alphas_cumprod, t, x.shape
        )
    
        sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
    
        # Call model (current image - noise prediction)
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
        )
    
        posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
    
        if t == 0:
            # As pointed out by Luis Pereira (see YouTube comment)
            # The t's are offset from the t's in the paper
            return model_mean
        else:
            noise = torch.randn_like(x)
            return model_mean + torch.sqrt(posterior_variance_t) * noise
    
    @torch.no_grad()
    def sample_plot_image(IMG_SIZE):
        # Sample noise
        img_size = IMG_SIZE
        img = torch.randn((1, 1, img_size, img_size), device=device)   # x_T: start from pure Gaussian noise
        plt.figure(figsize=(15,15))
        plt.axis('off')
        num_images = 10
        stepsize = int(T/num_images)
    
        for i in range(0,T)[::-1]:
            t = torch.full((1,), i, device=device, dtype=torch.long)
            #print("t:",t)
            img = sample_timestep(img, t)
            # Edit: This is to maintain the natural range of the distribution
            img = torch.clamp(img, -1.0, 1.0)
            if i % stepsize == 0:
                plt.subplot(1, num_images, int(i/stepsize)+1)
                plt.title(str(i))
                show_tensor_image(img.detach().cpu())
        plt.show()
    
    
    if __name__ =="__main__":
    
        # Define beta schedule
        T = 300
        betas = linear_beta_schedule(timesteps=T)
    
        # Pre-calculate different terms for closed form
        alphas = 1. - betas                                       # alpha_t = 1 - beta_t
        alphas_cumprod = torch.cumprod(alphas, dim=0)             # bar_alpha_t
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)  # bar_alpha_{t-1}, with bar_alpha_0 = 1
        sqrt_recip_alphas = torch.sqrt(1.0 / alphas)              # 1 / sqrt(alpha_t)
        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)          # sqrt(bar_alpha_t)
        sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)  # sqrt(1 - bar_alpha_t)
        # Posterior variance tilde_beta_t = beta_t * (1 - bar_alpha_{t-1}) / (1 - bar_alpha_t)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    
    
        IMG_SIZE = 32
        BATCH_SIZE = 16
    
        data = load_transformed_dataset(IMG_SIZE)
        dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    
        model = SimpleUnet()
        print("Num params: ", sum(p.numel() for p in model.parameters()))
    
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        optimizer = Adam(model.parameters(), lr=0.001)
        epochs = 1 # Try more!
    
        for epoch in range(epochs):
            for step, batch in enumerate(dataloader):  # each batch is (images, labels); only batch[0] is used
                #print(batch[0].shape)
                optimizer.zero_grad()
    
                t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
                loss = get_loss(model, batch[0], t)
                loss.backward()
                optimizer.step()
    
                if epoch % 1 == 0 and step %5== 0:
                    print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
                    sample_plot_image(IMG_SIZE)
    

    References

    https://zhuanlan.zhihu.com/p/630354327

    https://blog.csdn.net/weixin_45453121/article/details/131223653

    https://www.cnblogs.com/risejl/p/17448442.html

    https://zhuanlan.zhihu.com/p/569994589

  • Original article: https://blog.csdn.net/qq_46018418/article/details/134485773