由于以下推导需要用到
K
L
\rm KL
KL散度,这里先简单介绍一下。
K
L
\rm KL
KL散度一般用于度量两个概率分布函数之间的“距离”,其定义如下:
K
L
[
P
(
X
)
∣
∣
Q
(
X
)
]
=
∑
x
∈
X
[
P
(
x
)
log
P
(
x
)
Q
(
x
)
]
=
E
x
∼
P
(
x
)
[
log
P
(
x
)
Q
(
x
)
]
KL\big[P(X)||Q(X)\big]=\sum_{x\in X}\Big[P(x)\log\frac{P(x)}{Q(x)}\Big]=E_{x\sim P(x)}\Big[\log\frac{P(x)}{Q(x)}\Big]
KL[P(X)∣∣Q(X)]=∑x∈X[P(x)logQ(x)P(x)]=Ex∼P(x)[logQ(x)P(x)]
这里
P
(
X
)
P(X)
P(X)和
Q
(
X
)
Q(X)
Q(X)是两个概率分布函数,可以看到对于离散型随机变量,
K
L
\rm KL
KL散度对
x
x
x进行求和;对于连续型随机变量,
K
L
\rm KL
KL散度对
x
x
x进行积分(期望)。
高斯分布的
K
L
\rm KL
KL散度
对于两个单一变量的高斯分布
p
∼
N
(
μ
1
,
σ
1
2
)
p\sim\mathcal{N}(\mu_1, \sigma_1^2)
p∼N(μ1,σ12)和
q
∼
N
(
μ
2
,
σ
2
2
)
q\sim\mathcal{N}(\mu_2,\sigma_2^2)
q∼N(μ2,σ22)而言,它们的KL散度为
K
L
(
p
,
q
)
=
log
σ
2
σ
1
+
σ
1
2
+
(
μ
1
−
μ
2
)
2
2
σ
2
2
−
1
2
KL(p,q)=\log\frac{\sigma_2}{\sigma_1}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac{1}{2}
KL(p,q)=logσ1σ2+2σ22σ12+(μ1−μ2)2−21
下方是论文中给出的后向过程
x
t
−
1
\mathbf{x}_{t-1}
xt−1的分布,其方差为常数。
p
θ
(
x
0
:
T
)
=
p
(
x
T
)
∏
t
=
1
T
p
θ
(
x
t
−
1
∣
x
t
)
,
p
θ
(
x
t
−
1
∣
x
t
)
=
N
(
x
t
−
1
;
μ
θ
(
x
t
,
t
)
,
∑
θ
(
x
t
,
t
)
)
p_{\theta}(\mathbf{x}_{0:T})=p(\mathbf{x}_T)\prod_{t=1}^T p_{\theta}(\mathbf{x}_{t-1}\mid\mathbf{x}_t),\qquad p_{\theta}(\mathbf{x}_{t-1}\mid\mathbf{x}_t)=\mathcal{N}(\mathbf{x}_{t-1};\mu_{\theta}(\mathbf{x}_t,t),\sum_{\theta}(\mathbf{x}_t,t))
pθ(x0:T)=p(xT)∏t=1Tpθ(xt−1∣xt),pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),∑θ(xt,t))
推出扩散模型目标数据分布的似然函数,推出似然函数后才能优化模型。
p
θ
(
x
0
)
p_{\theta}(\mathbf{x}_0)
pθ(x0)为目标数据分布,其对数似然下界越大,那么对数似然越大。为了方便推导,这里用其负对数似然
−
log
p
θ
(
x
0
)
-\log p_{\theta}(\mathbf{x}_0)
−logpθ(x0)推导,其上界越小,负对数似然越小,相对应其对数似然越大。
−
log
p
θ
(
x
0
)
≤
−
log
p
θ
(
x
0
)
+
D
K
L
(
q
(
x
1
:
T
∣
x
0
)
∥
p
θ
(
x
1
:
T
∣
x
0
)
)
(
1
)
=
−
log
p
θ
(
x
0
)
+
E
x
1
:
T
∼
q
(
x
1
:
T
∣
x
0
)
[
log
q
(
x
1
:
T
∣
x
0
)
p
θ
(
x
0
:
T
)
/
p
θ
(
x
0
)
]
(
2
)
=
−
log
p
θ
(
x
0
)
+
E
q
[
log
q
(
x
1
:
T
∣
x
0
)
p
θ
(
x
0
:
T
)
+
log
p
θ
(
x
0
)
]
(
3
)
=
E
q
(
x
1
:
T
∣
x
0
)
[
log
q
(
x
1
:
T
∣
x
0
)
p
θ
(
x
0
:
T
)
]
(
4
)
然后我们将不等式左边的
−
log
p
θ
(
x
0
)
-\log p_{\theta}(\mathbf{x}_0)
−logpθ(x0)套上一个关于分布
q
(
x
0
)
q(\mathbf{x}_0)
q(x0)的期望,得到
−
E
q
(
x
0
)
log
p
θ
(
x
0
)
-\Bbb{E}_{q(\mathbf{x}_0)}\log p_{\theta}(\mathbf{x}_0)
−Eq(x0)logpθ(x0)(交叉熵,也即loss);相应的,不等式右边也要加上一个
x
0
\mathbf{x}_0
x0,则由
E
q
(
x
1
:
T
∣
x
0
)
\Bbb{E}_{q(\mathbf{x}_{1:T}\mid\mathbf{x}_0)}
Eq(x1:T∣x0)变为
E
q
(
x
0
:
T
)
\Bbb{E}_{q(\mathbf{x}_{0:T})}
Eq(x0:T)。如果我们想最小化loss,也就是最小化
E
q
(
x
0
:
T
)
\Bbb{E}_{q(\mathbf{x}_{0:T})}
Eq(x0:T)。
L
e
t
L
V
L
B
=
E
q
(
x
0
:
T
)
[
log
q
(
x
1
:
T
∣
x
0
)
p
θ
(
x
0
:
T
)
]
≥
−
E
q
(
x
0
)
log
p
θ
(
x
0
)
\rm Let\text{ }\it L_{\rm VLB} \it = \Bbb{E}_{q(\mathbf{x}_{0:T})}\Big[\log\frac{q(\mathbf{x}_{1:T}\mid\mathbf{x}_0)}{p_{\theta}(\mathbf{x}_{0:T})}\Big]\geq -\Bbb{E}_{q(\mathbf{x}_0)}\log p_{\theta}(\mathbf{x}_0)
Let LVLB=Eq(x0:T)[logpθ(x0:T)q(x1:T∣x0)]≥−Eq(x0)logpθ(x0)
L
V
L
B
=
E
q
(
x
0
:
T
)
[
log
q
(
x
1
:
T
∣
x
0
)
p
θ
(
x
0
:
T
)
]
(
1
)
=
E
[
log
∏
t
=
1
T
q
(
x
t
∣
x
t
−
1
)
p
θ
(
x
T
)
∏
t
=
1
T
p
θ
(
x
t
−
1
∣
x
t
)
]
(
2
)
=
E
q
[
−
log
p
θ
(
x
T
)
+
∑
t
=
1
T
log
q
(
x
t
∣
x
t
−
1
)
p
θ
(
x
t
−
1
∣
x
t
)
]
(
3
)
=
E
q
[
−
log
p
θ
(
x
T
)
+
∑
t
=
2
T
log
q
(
x
t
∣
x
t
−
1
)
p
θ
(
x
t
−
1
∣
x
t
)
+
log
q
(
x
1
∣
x
0
)
p
θ
(
x
0
∣
x
1
)
]
(
4
)
=
E
q
[
−
log
p
θ
(
x
T
)
+
∑
t
=
2
T
log
(
q
(
x
t
−
1
∣
x
t
,
x
0
)
p
θ
(
x
t
−
1
∣
x
t
)
⋅
q
(
x
t
∣
x
0
)
q
(
x
t
−
1
∣
x
0
)
)
+
log
q
(
x
1
∣
x
0
)
p
θ
(
x
0
∣
x
1
)
]
(
5
)
=
E
q
[
−
log
p
θ
(
x
T
)
+
∑
t
=
2
T
log
q
(
x
t
−
1
∣
x
t
,
x
0
)
p
θ
(
x
t
−
1
∣
x
t
)
+
∑
t
=
2
T
log
q
(
x
t
∣
x
0
)
q
(
x
t
−
1
∣
x
0
)
+
log
q
(
x
1
∣
x
0
)
p
θ
(
x
0
∣
x
1
)
]
(
6
)
=
E
q
[
−
log
p
θ
(
x
T
)
+
∑
t
=
2
T
log
q
(
x
t
−
1
∣
x
t
,
x
0
)
p
θ
(
x
t
−
1
∣
x
t
)
+
log
q
(
x
T
∣
x
0
)
q
(
x
1
∣
x
0
)
+
log
q
(
x
1
∣
x
0
)
p
θ
(
x
0
∣
x
1
)
]
(
7
)
=
E
q
[
log
q
(
x
T
∣
x
0
)
p
θ
(
x
T
)
+
∑
t
=
2
T
log
q
(
x
t
−
1
∣
x
t
,
x
0
)
p
θ
(
x
t
−
1
∣
x
t
)
−
log
p
θ
(
x
0
∣
x
1
)
]
(
8
)
=
E
q
[
D
K
L
(
q
(
x
T
∣
x
0
)
∥
p
θ
(
x
T
)
)
⏟
L
T
+
∑
t
=
2
T
D
K
L
(
q
(
x
t
−
1
∣
x
t
,
x
0
)
∥
p
θ
(
x
t
−
1
∣
x
t
)
)
⏟
L
t
−
1
−
log
p
θ
(
x
0
∣
x
1
)
⏟
L
0
]
(
9
)
在论文中,作者将分布
p
θ
(
x
t
−
1
∣
x
t
)
p_{\theta}(\mathbf{x}_{t-1}\mid\mathbf{x}_t)
pθ(xt−1∣xt)的方差看作与
β
\beta
β相关的常数,那么可训练的参数就存在于其均值当中。在
L
t
−
1
L_{t-1}
Lt−1中,
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(\mathbf{x}_{t-1}\mid\mathbf{x}_t,\mathbf{x}_0)
q(xt−1∣xt,x0)是一个高斯分布,其方差和均值我们已经在之前后向过程推导中求出,均值为
μ
~
t
(
x
t
)
\tilde{\mu}_t(\mathbf{x}_t)
μ~t(xt),方差为和
β
t
\beta_t
βt有关的常数。而
p
θ
(
x
t
−
1
∣
x
t
)
p_{\theta}(\mathbf{x}_{t-1}\mid\mathbf{x}_t)
pθ(xt−1∣xt)也是我们假设的高斯分布,它的方差也是常数,均值为
μ
θ
(
x
t
,
t
)
\mu_{\theta}(\mathbf{x}_t,t)
μθ(xt,t),所以参数只在
μ
θ
\mu_{\theta}
μθ当中。对于这两个高斯分布,我们可以运用高斯分布的
K
L
\rm KL
KL散度公式,其中的方差我们可以不考虑。则我们可以得到如下的式子:
L
t
−
1
=
E
q
[
1
2
σ
t
2
∥
μ
~
t
(
x
t
,
x
0
)
−
μ
θ
(
x
t
,
t
)
∥
2
]
+
C
L_{t-1}=\Bbb{E}_q \Big[\frac{1}{2\sigma_t^2} \lVert \tilde{\mu}_t(\mathbf{x}_t,\mathbf{x}_0)-\mu_{\theta}(\mathbf{x}_t,t)\rVert^2 \Big]+C
Lt−1=Eq[2σt21∥μ~t(xt,x0)−μθ(xt,t)∥2]+C
由这个式子,我们优化目标就很明确了,我们要优化
μ
θ
\mu_{\theta}
μθ,让其无线逼近于
μ
~
t
\tilde{\mu}_t
μ~t,这样才能使
L
t
−
1
L_{t-1}
Lt−1最小。首先我们将
μ
~
t
(
x
t
)
\tilde{\mu}_t(\mathbf{x}_t)
μ~t(xt)代入上述的式子中,原式中的
z
~
t
\tilde{z}_t
z~t用
ϵ
\epsilon
ϵ来表示,
x
t
\mathbf{x}_t
xt用
x
t
(
x
0
,
ϵ
)
\mathbf{x}_t(\mathbf{x}_0,\epsilon)
xt(x0,ϵ)替换,就能得到下方第二个等号的式子。
L
t
−
1
−
C
=
E
x
0
,
ϵ
[
1
2
σ
t
2
∥
μ
~
t
(
x
t
(
x
0
,
ϵ
)
,
1
α
ˉ
t
(
x
t
(
x
0
,
ϵ
)
−
1
−
α
ˉ
t
ϵ
)
)
−
μ
θ
(
x
t
(
x
0
,
ϵ
)
,
t
)
∥
2
]
=
E
x
0
,
ϵ
[
1
2
σ
t
2
∥
1
α
t
(
x
t
(
x
0
,
ϵ
)
−
β
t
1
−
α
ˉ
t
ϵ
)
−
μ
θ
(
x
t
(
x
0
,
ϵ
)
,
t
)
∥
2
]
这里我们的
x
t
\mathbf{x}_t
xt是已知的,那么为了使
L
t
−
1
L_{t-1}
Lt−1最小,我们可以将
μ
θ
(
x
t
,
t
)
\mu_{\theta}(\mathbf{x}_t,t)
μθ(xt,t)表示为
μ
~
t
\tilde{\mu}_t
μ~t的一个波动,其中的
ϵ
\epsilon
ϵ是未知的,则我们可以训练一个网络来预测
ϵ
\epsilon
ϵ。
μ
θ
(
x
t
,
t
)
=
μ
~
t
(
x
t
,
1
α
ˉ
t
(
x
t
−
1
−
α
ˉ
t
ϵ
θ
(
x
t
)
)
)
=
1
α
t
(
x
t
−
β
t
1
−
α
ˉ
t
ϵ
θ
(
x
t
,
t
)
)
\mu_{\theta}(\mathbf{x}_t,t)=\tilde{\mu}_t\Big(\mathbf{x}_t,\frac{1}{\sqrt{\bar{\alpha}_t}}(\mathbf{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t}\epsilon_{\theta}(\mathbf{x}_t)}) \Big)=\frac{1}{\sqrt{\alpha_t}}\Big(\mathbf{x}_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_{\theta}(\mathbf{x}_t,t) \Big)
μθ(xt,t)=μ~t(xt,αˉt1(xt−1−αˉtϵθ(xt)))=αt1(xt−1−αˉtβtϵθ(xt,t))
于是
L
t
−
1
L_{t-1}
Lt−1可以简化为如下形式
E
x
0
,
ϵ
[
β
t
2
2
σ
t
2
α
t
(
1
−
α
ˉ
t
)
∥
ϵ
−
ϵ
θ
(
α
ˉ
t
x
0
+
1
−
α
ˉ
t
ϵ
,
t
)
∥
2
]
\Bbb{E}_{\mathbf{x_0},\epsilon}\Big[ \frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha}_t)}\lVert \epsilon-\epsilon_{\theta}(\sqrt{\bar{\alpha}_t}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\epsilon,t)\rVert^2\Big]
Ex0,ϵ[2σt2αt(1−αˉt)βt2∥ϵ−ϵθ(αˉtx0+1−αˉtϵ,t)∥2]
作者又发现,将系数丢掉,训练更加稳定质量更好,于是就得到了下方的
L
s
i
m
p
l
e
L_{\rm simple}
Lsimple
L
s
i
m
p
l
e
(
θ
)
:
=
E
t
,
x
0
,
ϵ
[
∥
ϵ
−
ϵ
θ
(
α
ˉ
t
x
0
+
1
−
α
ˉ
t
ϵ
,
t
)
∥
2
]
L_{\rm simple}(\theta):=\Bbb{E}_{t,\mathbf{x_0},\epsilon}\Big[ \lVert \epsilon-\epsilon_{\theta}(\sqrt{\bar{\alpha}_t}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\epsilon,t)\rVert^2\Big]
Lsimple(θ):=Et,x0,ϵ[∥ϵ−ϵθ(αˉtx0+1−αˉtϵ,t)∥2]