• 生成模型的两大代表:VAE和GAN


    生成模型

    给定数据集,希望生成模型产生与训练集同分布的新样本。对于训练数据服从pdata(x);对于产生样本服从pmodel(x)。希望学到一个模型pmodel(x)pdata(x)尽可能接近。

    这也是无监督学习中的一个核心问题——密度估计问题。有两种典型的思路:

    1. 显式的密度估计:显式得定义并求解分布pmodel(x),如VAE。
    2. 隐式的密度估计:学习一个模型pmodel(x),而无需显式定义它,如GAN。

    VAE

    AE

    首先介绍下自编码器(Auto Encoder, AE),它将输入的图像X通过编码器encoder编码为一个隐向量(bottleneck)Z,然后再通过解码器decoder解码为重构图像X',它将自己编码压缩再还原故称自编码。结构如下图所示:

    以手写数字数据集MNIST为例,输入图像大小为28x28,通道数为1,定义隐向量的维度(latent_dim)为1 x N,N=20。经过编码器编码为一个长度为20的向量,再通过解码器解码为28x28大小的图像。将生成图像X'与原始图像X进行对比,计算重构误差,通过最小化误差优化模型参数:

    Loss=distance(X,X)

    一般distance距离函数选择均方误差(Mean Square Error, MSE)。AE与PCA作用相同,通过压缩数据实现降维,还能把降维后的数据进行重构生成图像,但PCA的通过计算特征值实现线性变换,而AE则是非线性。

    VAE

    如果中间的隐向量的每一分量取值不是直接来自Encoder,而是在一个分布上进行采样,那么就是变分自编码器(Variational Auto Encoder,VAE),结构如下图所示:

    还是上面的例子,这里的Z维度还是1 x 20,但是每一分量不是直接来自Encoder,而是在一个分布上进行采样计算,一般来说分布选择正态分布(当然也可以是其他分布)。每个正态分布的μσ由Encoder的神经网络计算而来。关于Z上每一分量的计算,这里,ϵ从噪声分布中随机采样得到。

    z(i,l)=μ(i)+σ(i)ϵ(l) and ϵ(l)N(0,I)

    在Encoder的过程中给定x得到z就是计算后验概率qϕ(z|x),学习得到的z为先验分布pθ(z),Decoder部分根据z计算x的过程就是似然估计pθ(x|z),训练的目的也是最大化似然估计(给出了z尽可能得还原为x)。

    边缘似然度pθ(x)=pθ(z)pθ(x|z)dz,边缘似然度又是每个数据点的边缘似然之和组成:logpθ(x(1),,x(N))=i=1Nlogpθ(x(i)),可以被重写为:

    logpθ(x(i))=DKL(qϕ(z|x(i))||pθ(z|x(i)))+L(θ,ϕ;x(i))

    pθ(z|x(i))通常被假设为标准正态分布,等式右边第二项称为边缘似然估计的下界,可以写为:

    logpθ(x(i))L(θ,ϕ;x(i))=Ezqϕ(z|x)[logqϕ(z|x)+logpθ(x|z)]

    得到损失函数:

    L(θ,ϕ;x(i))=DKL(qϕ(z|x(i))||pθ(z))+Ezqϕ(z|x(i))[logpθ(x(i)|z)]

    GAN

    生成对抗网络(Generative Adversarial Nets, GAN)需要同时训练两个模型:生成器(Generator, G)和判别器(Discriminator, D)。生成器的目标是生成与训练集同分布的样本,而判别器的目标是区分生成器生成的样本和训练集中的样本,两者相互博弈最后达到平衡(纳什均衡),生成器能够以假乱真,判别器无法区分真假。

    生成器和判别器最简单的应用就是分别设置为两个MLP。为了让生成器在数据x学习分布pg,定义一个噪声分布pz(z),然后使用生成器G(z;θg)将噪声映射为生成数据x'(θg是生成器模型参数)。同样定义判别器D(x;θd),输出为标量表示概率,代表输入的x来自数据还是pg。训练D时,以最大化分类训练样例还是G生成样本的概率准确性为目的;同时训练G以最小化log(1D(G(z)))为目的,两者互为博弈的双方,定义它们的最大最小博弈的价值函数V(G,D)

    minGmaxDV(D,G)=Expdata[logD(x)]+Ezpz[log(1D(G(z)))]

    可以得到生成器损失函数:LG=1mi=1mlog(1D(G(z(i))))

    判别器损失函数:LD=1mi=1m[logD(x(i))+log(1D(G(z(i))))]

    极端情况下如果D很完美,D(x)=1,D(G(z))=0,最后两项结果都为0,但如果存在误分类,由于log两项结果会变为负数。随着G的输出越来越像x导致D误判,价值函数V也会随之变小。

    计算它们的期望(Expf(x)=xp(x)f(x)dx):

    V(G,D)=xpdata(x)logD(x)dx+zpz(z)log(1D(G(z)))dz=xpdata(x)logD(x)+pg(x)log(1D(x))dx

    当D取到最优解时,上面的最大最小博弈价值函数V(G,D)可以写为:

    C(G)=maxDV(G,D)=Expdata[logpdata(x)pdata(x)+pg(x)]+Expg[logpg(x)pdata(x)+pg(x)]

    pg=pdata,取到log4,上式可以写成KL散度的形式:

    C(G)=log4+KL(pdata||pdata+pg2)+KL(pg||pdata+pg2)

    pg=pdata时,G取最小值也就是最优解。对于对称的KL散度,可以写成JS散度的形式:

    C(G)=2JS(pdata||pg)log4

    参考文献

    1. PyTorch-VAE-vanilla_vae.py
    2. Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes." arXiv preprint arXiv:1312.6114 (2013).
    3. DALL·E 2(内含扩散模型介绍)【论文精读】
    4. Goodfellow, Ian, et al. "Generative adversarial nets." Advances in neural information processing systems 27 (2014).
    5. 【概率论】先验概率、联合概率、条件概率、后验概率、全概率、贝叶斯公式
    6. 机器学习方法—优雅的模型(一):变分自编码器(VAE)
    7. GAN论文逐段精读【论文精读】
  • 相关阅读:
    vue3 globalData 的使用方法
    【RocketMQ】主从同步实现原理
    【音视频】C语言基础快速复习
    ToBeWritten之基于ATT&CK的模拟攻击:闭环的防御与安全运营
    运行 Triton 示例
    使用Fontcreator字体制作软件及字体设计学习
    为什么说快手在本地生活市场将大有所为
    创建自己的日期错误异常类
    《计算机网络 - 自顶向下方法》阅读笔记
    OpenCV笔记整理【视频处理】
  • 原文地址:https://www.cnblogs.com/zh-jp/p/17893306.html