Optimization objective:
$$
\mathcal{L}_{\mathrm{IWAE}}\left(\boldsymbol{x}_{1:M}\right)=\mathbb{E}_{\boldsymbol{z}^{1:K} \sim q_{\Phi}\left(\boldsymbol{z} \mid \boldsymbol{x}_{1:M}\right)}\left[\log \frac{1}{K} \sum_{k=1}^{K} \frac{p_{\Theta}\left(\boldsymbol{z}^k, \boldsymbol{x}_{1:M}\right)}{q_{\Phi}\left(\boldsymbol{z}^k \mid \boldsymbol{x}_{1:M}\right)}\right] \tag{1}
$$
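The following is a minimal sketch of a single Monte Carlo draw of the K-sample IWAE estimate in equation (1), for one observation. It is written in plain Python with the standard `math` and `random` modules rather than PyTorch. The model is a toy one-dimensional one chosen purely for illustration: $p(z)=\mathcal{N}(0,1)$ and $p(x\mid z)=\mathcal{N}(z,1)$, so $\log p(x)=\log\mathcal{N}(x;0,2)$ is known in closed form. The helper names are hypothetical. If the proposal $q$ is set to the exact posterior $\mathcal{N}(x/2, 1/2)$, every importance weight equals $p(x)$, and the estimate recovers $\log p(x)$ for any K.

```python
import math
import random


def log_normal(x, mean, var):
    """Log density of a univariate Gaussian N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def log_mean_exp(values):
    """Numerically stable log((1/n) * sum(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values)) - math.log(len(values))


def iwae_estimate(x, K, q_mean, q_var, rng):
    """One Monte Carlo draw of the K-sample IWAE bound, equation (1).

    Toy model (assumption): p(z) = N(0, 1), p(x | z) = N(z, 1),
    proposal q(z | x) = N(q_mean, q_var).
    """
    log_w = []
    for _ in range(K):
        z = rng.gauss(q_mean, math.sqrt(q_var))                       # z^k ~ q(z | x)
        log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)   # log p(z^k, x)
        log_q = log_normal(z, q_mean, q_var)                          # log q(z^k | x)
        log_w.append(log_joint - log_q)                               # log importance weight
    return log_mean_exp(log_w)                                        # log (1/K) sum_k w^k


rng = random.Random(0)
x = 1.3
exact = log_normal(x, 0.0, 2.0)                                    # log p(x) under the toy model
tight = iwae_estimate(x, K=5, q_mean=x / 2, q_var=0.5, rng=rng)    # exact posterior as q
loose = iwae_estimate(x, K=5, q_mean=0.0, q_var=1.0, rng=rng)      # prior as q
print(exact, tight, loose)
```

When the prior is used as the proposal, the result is a noisy estimate. It is a lower bound on $\log p(x)$ only in expectation, so a single draw can land on either side.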
Here the joint distribution factorizes over the M modalities given the shared latent variable $\boldsymbol{z}$:

$$
p_{\Theta}\left(\boldsymbol{z}, \boldsymbol{x}_{1:M}\right)=p(\boldsymbol{z}) \prod_{m=1}^{M} p_{\theta_m}\left(\boldsymbol{x}_m \mid \boldsymbol{z}\right)
$$
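In log space, this factorization turns the product into a sum: $\log p(\boldsymbol{z})$ plus one log-likelihood term per modality. The sketch below assumes, for illustration only, a standard normal prior and hypothetical linear-Gaussian decoders $p_{\theta_m}(x_m\mid z)=\mathcal{N}(w_m z, \sigma_m^2)$ with scalar $z$ and $x_m$. It checks the sum against the log of the explicit product.

```python
import math


def log_normal(x, mean, var):
    """Log density of a univariate Gaussian N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def log_joint(z, xs, decoders):
    """log p_Theta(z, x_1:M) = log p(z) + sum_m log p_theta_m(x_m | z).

    decoders: list of (w_m, var_m) pairs for hypothetical linear-Gaussian
    decoders p_theta_m(x_m | z) = N(w_m * z, var_m).
    """
    total = log_normal(z, 0.0, 1.0)                  # prior p(z) = N(0, 1)
    for x_m, (w_m, var_m) in zip(xs, decoders):
        total += log_normal(x_m, w_m * z, var_m)     # one term per modality
    return total


decoders = [(2.0, 0.5), (-1.0, 1.0)]                 # M = 2 modalities
xs = [1.0, 0.3]
z = 0.4
lj = log_joint(z, xs, decoders)

# The same value obtained by multiplying the densities and then taking the log
direct = math.log(
    math.exp(log_normal(z, 0.0, 1.0))
    * math.exp(log_normal(xs[0], 2.0 * z, 0.5))
    * math.exp(log_normal(xs[1], -1.0 * z, 1.0))
)
print(lj, direct)
```

Summing log densities, instead of multiplying densities, avoids the underflow that a product of many small per-modality likelihoods would cause.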
The posterior is approximated by an inference network, which gives the proposal density used in the denominator of (1):

$$
q_{\Phi}\left(\boldsymbol{z}^k \mid \boldsymbol{x}_{1:M}\right)
$$
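The text does not specify the distribution family of $q_\Phi$. A common choice, assumed here purely for illustration, is a diagonal Gaussian. In that case the inference network outputs a mean and a log-variance, and the K samples $\boldsymbol{z}^k$ are drawn with the reparameterization trick $z=\mu+\sigma\epsilon$, $\epsilon\sim\mathcal{N}(0,1)$. In this sketch, `mu` and `log_var` stand in for the outputs of a hypothetical encoder.

```python
import math
import random


def sample_and_log_q(mu, log_var, K, rng):
    """Draw K reparameterized samples from a diagonal Gaussian q(z | x).

    mu, log_var: lists of length D standing in for the outputs of a
    hypothetical inference network (assumption: Gaussian posterior family).
    Returns (samples, log_q), where samples[k] is a D-dim list and
    log_q[k] = log q(z^k | x).
    """
    samples, log_q = [], []
    for _ in range(K):
        z, lq = [], 0.0
        for m, lv in zip(mu, log_var):
            std = math.exp(0.5 * lv)
            eps = rng.gauss(0.0, 1.0)      # eps ~ N(0, 1)
            z.append(m + std * eps)        # z = mu + sigma * eps
            # log N(z; mu, sigma^2) written in terms of eps = (z - mu) / sigma
            lq += -0.5 * (math.log(2 * math.pi) + lv + eps ** 2)
        samples.append(z)
        log_q.append(lq)
    return samples, log_q


rng = random.Random(1)
mu = [0.5, -1.0]
var = [1.0, 0.25]
zs, lqs = sample_and_log_q(mu, [math.log(v) for v in var], K=3, rng=rng)
print(len(zs), len(zs[0]), lqs)
```

Because the sample is written as a deterministic function of `mu`, `log_var` and independent noise, gradients of the IWAE objective can flow back into the encoder parameters $\Phi$.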
The objective when the modalities are fused with a mixture of experts (MoE) is given below. K samples are drawn separately from each unimodal encoder $q_{\phi_m}(\boldsymbol{z}\mid\boldsymbol{x}_m)$, and the M resulting estimates are averaged. Every importance weight is still normalized by the full approximate posterior $q_{\Phi}(\boldsymbol{z}\mid\boldsymbol{x}_{1:M})$:

$$
\mathcal{L}_{\mathrm{IWAE}}^{\mathrm{MoE}}\left(\boldsymbol{x}_{1:M}\right)=\frac{1}{M} \sum_{m=1}^{M} \mathbb{E}_{\boldsymbol{z}_m^{1:K} \sim q_{\phi_m}\left(\boldsymbol{z} \mid \boldsymbol{x}_m\right)}\left[\log \frac{1}{K} \sum_{k=1}^{K} \frac{p_{\Theta}\left(\boldsymbol{z}_m^k, \boldsymbol{x}_{1:M}\right)}{q_{\Phi}\left(\boldsymbol{z}_m^k \mid \boldsymbol{x}_{1:M}\right)}\right] \tag{2}
$$
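The sketch below follows equation (2) step by step. It loops over the modalities, draws K samples from each expert, normalizes every weight by the uniform mixture, takes a log-mean-exp over the K samples, and averages over M. The model is a toy one chosen for illustration: $p(z)=\mathcal{N}(0,1)$, $p(x_m\mid z)=\mathcal{N}(z,1)$ for every modality, and hypothetical Gaussian experts passed in as `(mean, var)` pairs. As a check, if every expert is the exact joint posterior $\mathcal{N}\big((x_1+x_2)/3,\ 1/3\big)$, the mixture equals that posterior and the estimate equals $\log p(x_1,x_2)$ exactly.

```python
import math
import random


def log_normal(x, mean, var):
    """Log density of a univariate Gaussian N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def log_mean_exp(values):
    """Numerically stable log((1/n) * sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values)) - math.log(len(values))


def moe_iwae(xs, encoders, K, rng):
    """One Monte Carlo draw of the MoE IWAE objective, equation (2).

    Toy model (assumption): p(z) = N(0, 1) and p(x_m | z) = N(z, 1) for every m.
    encoders: list of (mean, var) for hypothetical Gaussian experts
    q_phi_m(z | x_m), one per modality.
    """
    M = len(xs)
    per_modality = []
    for q_mean, q_var in encoders:                        # outer sum over m
        log_w = []
        for _ in range(K):
            z = rng.gauss(q_mean, math.sqrt(q_var))       # z_m^k ~ q_phi_m(z | x_m)
            log_joint = log_normal(z, 0.0, 1.0) + sum(log_normal(x, z, 1.0) for x in xs)
            # Denominator: log q_Phi(z | x_1:M) = log (1/M) sum_m' q_phi_m'(z | x_m')
            log_q_mix = log_mean_exp([log_normal(z, mu, v) for mu, v in encoders])
            log_w.append(log_joint - log_q_mix)
        per_modality.append(log_mean_exp(log_w))          # log (1/K) sum_k w_m^k
    return sum(per_modality) / M                          # (1/M) sum_m


x1, x2 = 0.8, -0.2
unimodal = [(x1 / 2, 0.5), (x2 / 2, 0.5)]   # exact single-modality posteriors in the toy model
print(moe_iwae([x1, x2], unimodal, K=10, rng=random.Random(0)))
```

Drawing a fixed K samples from each expert, instead of sampling from the mixture, ensures that every modality's encoder receives gradient signal at every step.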
Under the MoE approach, the approximate posterior is a mixture of the unimodal posteriors:

$$
q_{\Phi}\left(\boldsymbol{z} \mid \boldsymbol{x}_{1:M}\right)=\sum_{m=1}^{M} \alpha_m \, q_{\phi_m}\left(\boldsymbol{z} \mid \boldsymbol{x}_m\right)
$$

with uniform weights $\alpha_m = \frac{1}{M}$.
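Evaluating $\log \sum_m \alpha_m q_{\phi_m}$ by exponentiating each expert's log density can underflow to $\log 0$. The usual remedy is a log-sum-exp over $\log\alpha_m + \log q_{\phi_m}$, and with $\alpha_m=1/M$ this is exactly a log-mean-exp. This is a sketch with hypothetical univariate Gaussian experts given as `(mean, var)` pairs.

```python
import math


def log_normal(x, mean, var):
    """Log density of a univariate Gaussian N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def log_mixture_density(z, experts, weights=None):
    """log q_Phi(z | x_1:M) = log sum_m alpha_m q_phi_m(z | x_m), computed stably.

    experts: list of (mean, var) for hypothetical Gaussian experts.
    weights: mixture weights alpha_m; defaults to the uniform 1/M used by MoE.
    """
    M = len(experts)
    if weights is None:
        weights = [1.0 / M] * M
    terms = [math.log(a) + log_normal(z, mu, v) for a, (mu, v) in zip(weights, experts)]
    peak = max(terms)                                  # shift so the largest term is exp(0)
    return peak + math.log(sum(math.exp(t - peak) for t in terms))


experts = [(0.0, 1.0), (1.0, 1.0)]
z = 0.3
stable = log_mixture_density(z, experts)
naive = math.log(0.5 * math.exp(log_normal(z, 0.0, 1.0)) + 0.5 * math.exp(log_normal(z, 1.0, 1.0)))
far = log_mixture_density(40.0, experts)   # the naive version would call math.log(0.0) here
print(stable, naive, far)
```

At `z = 40.0`, both densities underflow to `0.0` in double precision, so the naive formula raises a `ValueError`. The shifted version stays finite.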
The main body of the code that computes the IWAE:

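Going through the for loop line by line, taking r = 0 (the samples drawn from the first modality's encoder) as an example: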
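The final operation,

`l1 = log_mean_exp(lw, dim=0)`

averages over the K samples. For M = 2 modalities, the importance weight of each sample is

$$
\frac{p(z_1)\, p(x_1 \mid z_1)\, p(x_2 \mid z_1)}{\frac{1}{2}\left[q(z_1 \mid x_1) + q(z_1 \mid x_2)\right]} \tag{3}
$$

where $z_1$ denotes a sample from $q_{\phi_1}(z \mid x_1)$. `l1` is therefore the logarithm of the average of (3) over the K samples.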
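This is exactly the m = 1 term of equation (2). Working through the loop line by line in this way shows clearly how the code implements the multimodal variant of IWAE.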
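To confirm that the log-space computation matches equation (3), the sketch below reproduces the r = 0 iteration with M = 2. It compares `l1` against the same quantity computed directly from densities. The original loop body is not reproduced here, so the toy model ($p(z)=\mathcal{N}(0,1)$, $p(x_m\mid z)=\mathcal{N}(z,1)$) and the hypothetical experts $q(z\mid x_m)=\mathcal{N}(x_m/2, 1/2)$ are assumptions. Only `lw`, `l1` and `log_mean_exp` mirror names from the text.

```python
import math
import random


def log_normal(x, mean, var):
    """Log density of a univariate Gaussian N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def density(v, mean, var):
    """Gaussian density N(mean, var) at v (not in log space)."""
    return math.exp(log_normal(v, mean, var))


def log_mean_exp(values):
    """Numerically stable log((1/n) * sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values)) - math.log(len(values))


# Toy setup (assumption): two modalities with hypothetical Gaussian experts.
x = [0.8, -0.2]
experts = [(xm / 2, 0.5) for xm in x]
K = 4
rng = random.Random(0)

# r = 0: draw K samples from the first modality's expert q(z | x_1)
z1 = [rng.gauss(experts[0][0], math.sqrt(experts[0][1])) for _ in range(K)]

lw = []
for z in z1:
    log_pz = log_normal(z, 0.0, 1.0)                                     # log p(z_1)
    log_px_z = sum(log_normal(xm, z, 1.0) for xm in x)                   # log p(x_1|z_1) + log p(x_2|z_1)
    log_qz_x = log_mean_exp([log_normal(z, mu, v) for mu, v in experts])  # log (1/2)[q(z_1|x_1) + q(z_1|x_2)]
    lw.append(log_pz + log_px_z - log_qz_x)                              # log of the weight in (3)

l1 = log_mean_exp(lw)                                                    # log (1/K) sum_k w^k

# Equation (3) evaluated directly with densities, then averaged and logged
weights = [
    density(z, 0.0, 1.0) * density(x[0], z, 1.0) * density(x[1], z, 1.0)
    / (0.5 * (density(z, *experts[0]) + density(z, *experts[1])))
    for z in z1
]
direct = math.log(sum(weights) / K)
print(l1, direct)
```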
Here `log_mean_exp` is defined as follows:
```python
import math
import torch

def log_mean_exp(value, dim=0, keepdim=False):
    return torch.logsumexp(value, dim, keepdim=keepdim) - math.log(value.size(dim))  # sum -> mean
```
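As the names suggest, the only difference between `log_mean_exp` and `torch.logsumexp` is that the former averages the exponentials while the latter sums them. Subtracting $\log n$, where $n$ = `value.size(dim)`, turns the sum into a mean.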
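For readers without PyTorch, here is a plain-Python sketch of the same two operations on a list of floats. It assumes the usual max-shift trick for numerical stability (`torch.logsumexp` is documented as numerically stabilized). It also shows why working in log space matters: the naive sum of exponentials underflows to zero.

```python
import math


def log_sum_exp(values):
    """Stable log(sum(exp(v))): shift by the max so the largest term becomes exp(0) = 1."""
    peak = max(values)
    return peak + math.log(sum(math.exp(v - peak) for v in values))


def log_mean_exp(values):
    """log((1/n) * sum(exp(v))) = log_sum_exp(values) - log(n)."""
    return log_sum_exp(values) - math.log(len(values))


vals = [-1000.0, -1000.0, -1000.0]
naive_sum = sum(math.exp(v) for v in vals)   # each exp(-1000) underflows to 0.0
print(naive_sum)
print(log_sum_exp(vals))                     # approximately -1000 + log(3)
print(log_mean_exp(vals))                    # approximately -1000
```

Calling `math.log(naive_sum)` here would raise a `ValueError`. The shifted versions return finite values, which is why the IWAE code keeps the importance weights in log space throughout.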
