• 变分自编码器(VAE)公式推导


    论文原文:Auto-Encoding Variational Bayes [OpenReview (ICLR 2014) | arXiv]

    本文记录了我在学习 VAE 过程中的一些公式推导和思考。如果你希望从头开始学习 VAE,建议先看一下苏剑林的博客(本文末尾有链接)。

    VAE 的整体框架

    VAE 认为,随机变量 xp(x) 由两个随机过程得到:

    1. 根据先验分布 p(z) 生成隐变量 z
    2. 根据条件分布 p(x|z)z 得到 x

    于是 p(x,z)=p(z)p(x|z) 就是我们所需要的生成模型。

    一种朴素的想法是:先用随机数生成器生成隐变量 z,然后用 p(x|z)z 中生成出(或者说重构出) x,通过最小化重构损失来训练模型。这个想法的问题在于:我们无法找到生成的样本与原始样本之间的对应关系,重构损失算不了,无法训练。

    VAE 的做法是引入后验分布 p(z|x),训练过程变为:

    1. 采样一批原始样本 x
    2. p(z|x) 获得每个样本 x 对应的隐变量 z
    3. p(x|z) 从隐变量 z 中重构出 x,通过最小化重构损失来训练模型。

    从这个角度来看,p(z|x) 相当于编码器p(x|z) 相当于解码器,训练结束后只需要保留解码器 p(x|z) 即可。

    除了重构损失以外,VAE 还有一项 KL 散度损失,希望近似的后验分布 q(z|x) 尽量接近先验分布 p(z),即最小化二者的 KL 散度。

    变分下界的推导

    现有 N 个由分布 P(x;θ) 生成的样本 x(1),,x(N),我们可以使用极大似然估计从这些样本中估计出分布的参数 θ,即

    θ=argmaxθp(x(1);θ)p(x(N);θ)=argmaxθln(p(x(1);θ)p(x(N);θ))=argmaxθi=1nlnp(x(i);θ).

    后验分布 p(z|x)=p(z)p(x|z)p(x)=p(z)p(x|z)zp(x,z)dz 是 intractable 的,因为分母处的边缘分布 p(x) 积不出来。具体来说,联合分布 p(x,z)=p(z)p(x|z) 的表达式非常复杂,zp(x,z)dz 这个积分找不到解析解。

    需要使用变分推断解决后验分布无法计算的问题。我们使用一个形式已知的分布 q(z|x(i);ϕ)近似后验分布 p(z|x(i);θ),于是有

    logp(x(i))=zq(z|x(i))[logq(z|x(i))logp(z|x(i))]dz+zq(z|x(i))[logq(z|x(i))+logp(z|x(i))]dz+logp(x(i))1=zq(z|x(i))logq(z|x(i))p(z|x(i))dz+zq(z|x(i))[logq(z|x(i))+logp(z|x(i))]dz+logp(x(i))zq(z|x(i))dz=KL[q(z|x(i)),p(z|x(i))]+zq(z|x(i))[logq(z|x(i))+logp(z|x(i))]dz+zq(z|x(i))logp(x(i))dz=KL[q(z|x(i)),p(z|x(i))]+zq(z|x(i))[logq(z|x(i))+logp(z|x(i))+logp(x(i))]dz=KL[q(z|x(i)),p(z|x(i))]+zq(z|x(i))[logq(z|x(i))+log(p(z|x(i))p(x(i)))]dz=KL[q(z|x(i)),p(z|x(i))]+zq(z|x(i))[logq(z|x(i))+logp(x(i),z)]dz=KL[q(z|x(i)),p(z|x(i))]+Ezq(z|x(i))[logq(z|x(i))+logp(x(i),z)]=KL[q(z|x(i)),p(z|x(i))]+L(θ,ϕ;x(i))L(θ,ϕ;x(i)).

    利用 KL 散度大于等于 0 这一特性,我们得到了对数似然 logp(x(i)) 的一个下界 L(θ,ϕ;x(i)),于是可以将最大化对数似然改为最大化这个下界。

    这个下界可以进一步写成

    L(θ,ϕ;x(i))=zq(z|x(i))[logq(z|x(i))+logp(x(i),z)]dz=zq(z|x(i))[logq(z|x(i))+log(p(z)p(x(i)|z))]dz=zq(z|x(i))[logq(z|x(i))+logp(z)+logp(x(i)|z)]dz=zq(z|x(i))[logq(z|x(i))logp(z)]dz+zq(z|x(i))logp(x(i)|z)]dz=KL[q(z|x(i)),p(z)]+Ezq(z|x(i))[logp(x(i)|z)].

    其中的第一项是 KL 散度损失,第二项是重构损失。

    KL 散度损失

    使用标准正态分布作为先验分布,即 p(z)=N(z;0,I)

    使用一个由 MLP 的输出来参数化的正态分布作为近似后验分布,即 q(z|x(i);ϕ)=N(z;μ(x(i);ϕ),σ2(x(i);ϕ)I)

    选择正态分布的好处在于 KL 散度的这个积分可以写出解析解,训练时直接按照公式计算即可,无需通过采样的方式来算积分。

    由于我们选择的是各分量独立的多元正态分布,因此只需要推导一元正态分布的情形即可:

    KL[N(z;μ,σ2),N(z;0,1)]=zN(z;μ,σ2)logN(z;μ,σ2)N(z;0,1)dz=zN(z;μ,σ2)log12πσexp((zμ)22σ2)12πexp(z22)dz=zN(z;μ,σ2)log(1σ2exp(12((zμ2)2σ2+z2)))dz=12zN(z;μ,σ2)(logσ2(zμ)2σ2+z2)dz=12(logσ2zN(z;μ,σ2)dz1σ2zN(z;μ,σ2)(zμ)2dz+zN(z;μ,σ2)z2dz)=12(logσ211σ2σ2+μ2+σ2)=12(logσ21+μ2+σ2).

    解释一下倒数第三行的三个积分:

    1. zN(z;μ,σ2)dz 是概率密度函数的积分,也就是 1。
    2. zN(z;μ,σ2)(zμ)2dz 是方差的定义,也就是 σ2
    3. zN(z;μ,σ2)z2dz 是正态分布的二阶矩,结果为 μ2+σ2

    重构损失

    伯努利分布模型

    x 是二值向量时,可以用伯努利分布(两点分布)来建模 p(x|z),即认为向量 x 的每个维度都服从对应的相互独立的伯努利分布。使用一个 MLP 来计算各维度所对应的伯努利分布的参数,第 i 维伯努利分布的参数为 yi=y(z)i,于是有

    p(x|z)=i=1Dyixi(1yi)1xi,

    logp(x|z)=i=1Dxilogyi+(1xi)log(1yi).

    其中 D 表示向量 x 的维度。可见此时最大化 logp(x|z) 等价于最小化交叉熵损失。

    正态分布模型

    x 是实值向量时,可以用正态分布来建模 p(x|z)。使用一个 MLP 来计算正态分布的参数,于是有

    p(x|z)=N(x;μ,σ2I)=i=1DN(xi;μi,σi2)=(i=1D12πσi)exp(i=1D(xiμi)22σi2),

    logp(x|z)=D2log2π12i=1Dlogσi212i=1D(xiμi)2σi2.

    很多时候我们会假设 σi2 是一个常数,于是 MLP 只需要输出均值参数 μ 即可。此时有

    logp(x|z)12i=1D(xiμi)2=12xμ(z)2.

    可见此时最大化 logp(x|z) 等价于最小化 MSE 损失。

    重参数化技巧

    需要使用重参数化技巧解决采样 z 时不可导的问题。解决的思路是先从无参数分布中采样一个 ε,再通过变换得到 z

    N(μ,σ2) 中采样一个 z,相当于先从 N(0,1) 中采样一个 ε,然后令 z=μ+εσ

    相关知识

    技巧,通过取对数把乘除变成加减:

    lnab=lna+lnb, lnab=lnalnb.

    随机变量的函数的期望:

    ExP(x)g(x)=xp(x)g(x)dx,

    利用此公式可以将积分改写成期望的形式,这样就可以用采样的方式计算积分了(蒙特卡罗积分法)。

    条件概率密度的定义:

    pY|X(y|x)=p(x,y)pX(x),

    此处的 p 并不是概率而是概率密度函数,但是这个公式在形式上跟条件概率公式是一样的。

    参考资料

    苏剑林的 VAE 系列博客:

    15 分钟了解变分推理:

  • 相关阅读:
    ffmpeg 滤镜实现不同采样率多音频混音
    Python Opencv实践 - HoG特征计算
    像你这么优秀的测试工程师,怎么就约不到面试呢?
    【Verilog 教程】3.1 Verilog 连续赋值
    基于Spring AOP和CGLIB代理实现引介增强(Introduction Advice)示例
    【Numpy-矩阵库~python】
    【开发心得】Jaxb使用珠玑
    java-php-python-ssm在校大学生健康状况信息管理系统计算机毕业设计
    视频怎么制作成gif动画?这个方法试试看
    Python Web开发二:Django的安装和运行
  • 原文地址:https://www.cnblogs.com/zhb2000/p/variational-autoencoder.html