• 机器学习笔记之变分推断(五)重参数化技巧


    引言

    上一节介绍了随机梯度变分推断(Stochastic Gradient Variational Inference,SGVI)。本节将介绍SGVI求解过程中遇到的问题,并针对为题介绍一种处理方法——重参数化技巧

    回顾:随机梯度变分推断

    由于基于平均场假设的经典变分推断(Classical Variational Inference)的假设条件非常苛刻,基本无法在真实环境中使用

    因此,介绍了随机梯度变分推断,从 P ( Z ∣ X ) P(\mathcal Z \mid \mathcal X) P(ZX)整体角度进行求解。

    SGVI的核心是将分布 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)视为概率模型,既然是概率模型,自然存在描述概率模型的模型参数
    这里定义 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)的模型参数为 ϕ \phi ϕ,将求解 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)的梯度转化为求解模型参数 ϕ \phi ϕ的梯度
    实际上, Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)本身并不是‘隐变量的边缘概率分布’,而是条件概率分布 Q ( Z ∣ X ) \mathcal Q(\mathcal Z \mid \mathcal X) Q(ZX)。只是 X \mathcal X X是观测数据,是已知量,因此省略。
    Q ( Z ) → Q ( Z ∣ ϕ ) L [ Q ( Z ) ] = ∫ Z ∣ ϕ Q ( Z ∣ ϕ ) ⋅ log ⁡ [ P ( X , Z ) Q ( Z ∣ ϕ ) ] d Z = L ( ϕ ) \mathcal Q(\mathcal Z) \to \mathcal Q(\mathcal Z \mid \phi) \\ L[Q(Z)]=ZϕQ(Zϕ)log[P(X,Z)Q(Zϕ)]dZ=L(ϕ)

    Q(Z)Q(Zϕ)L[Q(Z)]=ZϕQ(Zϕ)log[Q(Zϕ)P(X,Z)]dZ=L(ϕ)
    随后 L ( ϕ ) \mathcal L(\phi) L(ϕ) ϕ \phi ϕ求解梯度,最终化简结果如下:
    ∇ ϕ L ( ϕ ) = E Q ( Z ∣ ϕ ) { ∇ ϕ log ⁡ Q ( Z ∣ ϕ ) ⋅ [ l o g P ( X , Z ) − log ⁡ Q ( Z ∣ ϕ ) ] } \nabla_{\phi} \mathcal L(\phi) = \mathbb E_{\mathcal Q(\mathcal Z \mid \phi)} \left\{ \nabla_{\phi} \log \mathcal Q(\mathcal Z \mid \phi) \cdot \left[log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right] \right\} ϕL(ϕ)=EQ(Zϕ){ ϕlogQ(Zϕ)[logP(X,Z)logQ(Zϕ)]}
    至此,将梯度结果 ∇ ϕ L ( ϕ ) \nabla_{\phi}\mathcal L(\phi) ϕL(ϕ)表示为期望形式,后续操作可以通过蒙特卡洛采样方法对期望结果进行估计
    假设从 概率模型 P ( Z ∣ ϕ ) P(\mathcal Z \mid \phi) P(Zϕ)中采集了 N N N个样本。即:
    z ( n ) ∼ Q ( Z ∣ ϕ ) ( n = 1 , 2 , ⋯   , N ) z^{(n)} \sim \mathcal Q(\mathcal Z \mid \phi) \quad (n=1,2,\cdots,N) z(n)Q(Zϕ)(n=1,2,,N)
    上述期望使用蒙特卡洛采样方法近似表示为
    ∇ ϕ L ( ϕ ) ≈ 1 N ∑ n = 1 N { ∇ ϕ log ⁡ Q ( z ( n ) ∣ ϕ ) [ log ⁡ P ( X , z ( n ) ) − log ⁡ Q ( z ( n ) ∣ ϕ ) ] } \nabla_{\phi}\mathcal L(\phi) \approx \frac{1}{N} \sum_{n=1}^{N} \left\{\nabla_{\phi} \log \mathcal Q(z^{(n)} \mid \phi) \left[ \log P(\mathcal X,z^{(n)}) - \log \mathcal Q(z^{(n)} \mid \phi)\right]\right\} ϕL(ϕ)N1n=1N{ ϕlogQ(z(n)ϕ)[logP(X,z(n))logQ(z(n)ϕ)]}

    随机梯度变分推断的问题

    公式推导方式本身没有问题,问题在于采样过程中出现的高方差现象(High Variance)。该现象产生的具体原因如下:

    • 观察基于蒙特卡洛方法的近似公式,大括号内主要包含两项,两项均包含 z ( n ) z^{(n)} z(n)。观察第一项
      ∇ ϕ log ⁡ Q ( z ( n ) ∣ ϕ ) \nabla_{\phi} \log \mathcal Q(z^{(n)} \mid\phi) ϕlogQ(z(n)ϕ)

    • 注意,该项并不是求解 log ⁡ Q ( z ( n ) ∣ ϕ ) \log \mathcal Q(z^{(n)} \mid \phi) logQ(z(n)ϕ)的结果,而是该结果的梯度 Q ( z ( n ) ∣ ϕ ) \mathcal Q(z^{(n)} \mid \phi) Q(z(n)ϕ)是一个 描述概率的函数,因此它的值域是 ( 0 , 1 ) (0,1) (0,1)。观察 log ⁡ Q ( z ( n ) ∣ ϕ ) \log \mathcal Q(z^{(n)} \mid \phi) logQ(z(n

  • 相关阅读:
    基于SqlSugar的开发框架循序渐进介绍(26)-- 实现本地上传、FTP上传、阿里云OSS上传三者合一处理
    《Deep Convolution Neural Networks for Twitter Sentiment Analysis》文献研读
    orbslam2 安装过程记录
    TSN网络中的Qbu和802.3br
    倾斜摄影技术构建图扑 WebGIS 智慧展馆
    在KubeSphere启用基于Jenkins的DevOps
    Games101作业0(vscode连接VB虚拟机)
    Highcharts JS 10.3.1 开心没水印
    2023年工单管理系统排行榜单
    docker、docker-compose 下安装elasticsearch、IK分词器
  • 原文地址:https://blog.csdn.net/qq_34758157/article/details/126936729