上一节介绍了随机梯度变分推断(Stochastic Gradient Variational Inference,SGVI)。本节将介绍SGVI求解过程中遇到的问题,并针对为题介绍一种处理方法——重参数化技巧。
由于基于平均场假设的经典变分推断(Classical Variational Inference)的假设条件非常苛刻,基本无法在真实环境中使用。
因此,介绍了随机梯度变分推断,从 P ( Z ∣ X ) P(\mathcal Z \mid \mathcal X) P(Z∣X)整体角度进行求解。
SGVI的核心是将分布 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)视为概率模型,既然是概率模型,自然存在描述概率模型的模型参数。
这里定义 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)的模型参数为 ϕ \phi ϕ,将求解 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)的梯度转化为求解模型参数 ϕ \phi ϕ的梯度。
实际上,
Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)本身并不是‘隐变量的边缘概率分布’,而是条件概率分布
Q ( Z ∣ X ) \mathcal Q(\mathcal Z \mid \mathcal X) Q(Z∣X)。只是
X \mathcal X X是观测数据,是已知量,因此省略。
Q ( Z ) → Q ( Z ∣ ϕ ) L [ Q ( Z ) ] = ∫ Z ∣ ϕ Q ( Z ∣ ϕ ) ⋅ log [ P ( X , Z ) Q ( Z ∣ ϕ ) ] d Z = L ( ϕ ) \mathcal Q(\mathcal Z) \to \mathcal Q(\mathcal Z \mid \phi) \\ L[Q(Z)]=∫Z∣ϕQ(Z∣ϕ)⋅log[P(X,Z)Q(Z∣ϕ)]dZ=L(ϕ)
随后 L ( ϕ ) \mathcal L(\phi) L(ϕ)对 ϕ \phi ϕ求解梯度,最终化简结果如下:
∇ ϕ L ( ϕ ) = E Q ( Z ∣ ϕ ) { ∇ ϕ log Q ( Z ∣ ϕ ) ⋅ [ l o g P ( X , Z ) − log Q ( Z ∣ ϕ ) ] } \nabla_{\phi} \mathcal L(\phi) = \mathbb E_{\mathcal Q(\mathcal Z \mid \phi)} \left\{ \nabla_{\phi} \log \mathcal Q(\mathcal Z \mid \phi) \cdot \left[log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right] \right\} ∇ϕL(ϕ)=EQ(Z∣ϕ){
∇ϕlogQ(Z∣ϕ)⋅[logP(X,Z)−logQ(Z∣ϕ)]}
至此,将梯度结果 ∇ ϕ L ( ϕ ) \nabla_{\phi}\mathcal L(\phi) ∇ϕL(ϕ)表示为期望形式,后续操作可以通过蒙特卡洛采样方法对期望结果进行估计。
假设从 概率模型 P ( Z ∣ ϕ ) P(\mathcal Z \mid \phi) P(Z∣ϕ)中采集了 N N N个样本。即:
z ( n ) ∼ Q ( Z ∣ ϕ ) ( n = 1 , 2 , ⋯ , N ) z^{(n)} \sim \mathcal Q(\mathcal Z \mid \phi) \quad (n=1,2,\cdots,N) z(n)∼Q(Z∣ϕ)(n=1,2,⋯,N)
上述期望使用蒙特卡洛采样方法近似表示为:
∇ ϕ L ( ϕ ) ≈ 1 N ∑ n = 1 N { ∇ ϕ log Q ( z ( n ) ∣ ϕ ) [ log P ( X , z ( n ) ) − log Q ( z ( n ) ∣ ϕ ) ] } \nabla_{\phi}\mathcal L(\phi) \approx \frac{1}{N} \sum_{n=1}^{N} \left\{\nabla_{\phi} \log \mathcal Q(z^{(n)} \mid \phi) \left[ \log P(\mathcal X,z^{(n)}) - \log \mathcal Q(z^{(n)} \mid \phi)\right]\right\} ∇ϕL(ϕ)≈N1n=1∑N{
∇ϕlogQ(z(n)∣ϕ)[logP(X,z(n))−logQ(z(n)∣ϕ)]}
公式推导方式本身没有问题,问题在于采样过程中出现的高方差现象(High Variance)。该现象产生的具体原因如下:
观察基于蒙特卡洛方法的近似公式,大括号内主要包含两项,两项均包含 z ( n ) z^{(n)} z(n)。观察第一项:
∇ ϕ log Q ( z ( n ) ∣ ϕ ) \nabla_{\phi} \log \mathcal Q(z^{(n)} \mid\phi) ∇ϕlogQ(z(n)∣ϕ)
注意,该项并不是求解 log Q ( z ( n ) ∣ ϕ ) \log \mathcal Q(z^{(n)} \mid \phi) logQ(z(n)∣ϕ)的结果,而是该结果的梯度。 Q ( z ( n ) ∣ ϕ ) \mathcal Q(z^{(n)} \mid \phi) Q(z(n)∣ϕ)是一个 描述概率的函数,因此它的值域是 ( 0 , 1 ) (0,1) (0,1)。观察 log Q ( z ( n ) ∣ ϕ ) \log \mathcal Q(z^{(n)} \mid \phi) logQ(z(n