【PRML】如何简单易懂地理解变分推断(Variational Inference)

写在前面

照搬了一部分去年《概率论与数理统计》课程自己的期中论文,当时啃了大名鼎鼎的经典著作 《模式识别与机器学习》(Pattern Recognition and Machine Learning,PRML) 大概三四章,没按顺序,并且啃地也较快,后面也没时间填坑。现在继续来填概率论的大坑,不过应该也不会特别细,楼主太菜了,能理解就理解,至于一些需要记忆的一般偏向于用到了现查。

一、为什么需要变分推断?

先来简单地回顾一下贝叶斯推断,即已知 xx ,我们希望知道潜在变量 zz 的概率分布,即所谓的后验概率分布 p(zx)p(z | x)

p(zx)=p(xz)p(z)p(x)=p(x,z)p(x,z)dzp(z | x) = \frac{p(x | z)p(z)}{p(x)} = \frac{p(x, z)}{\int p(x, z) dz}

其中:

  • xx :观察数据 ;
  • zz :潜在变量(latent variable);
  • p(x,z)=p(xz)p(z)p(x, z) = p(x|z)p(z) :联合分布;
  • p(zx)p(z|x) :后验分布;
  • p(x)=p(x,z)dzp(x) = \int p(x,z)dz :边缘似然;

实际情况中,往往由于变量维度较高,或者被积函数过于复杂,导致精确计算边缘似然的代价过高。这时候就需要进行近似推断,典型代表之一就是变分推断

二、变分推断的核心思想

用一个可调的、简化的分布 q(z)q(z)逼近真实后验 p(zx)p(z | x),使它们越像越好。计算后验概率的问题转换为一个最优化问题,就是让 q(z)q(z) 尽可能接近 p(zx)p(z|x) ,而衡量两个分布之间的差异性,就要用到 KL 散度了。

三、何为变分?

变分听起来是一个很高深的名词,实则也确实一点都不简单(对笔者而言)。

我们可以把函数 y(x)y(x) 看成⼀个运算符。对于任意输⼊ xx ,这个运算符都能返回⼀个输出 yy ,或者说 yy 是 关于 xx 的一个实数算子,将实数 xx 映射到实数 y(x)y(x)

类比一下,可以定义一个泛函 F[y]F[y] ,它是一个运算符,它以 y(x)y(x) 作为输入,输出 FF 。这样讲可能比较抽象,所以举几个例子:

  1. 泛函表示二维空间中一条曲线的长度,曲线的轨迹根据函数来定义;
  2. 在 ML 领域,常见的就是随机变量 xx 的熵 H(x)H(x)

两个点之间有无数条路径,每一条路径都是一个函数,当函数变化了一点而导致泛函值变化了多少,这就是变分。传统微积分中,希望找到一个 xx 使得 y(x)y(x) 取到最大值或最小值。同样,变分法中,希望找到一个函数 y(x)y(x) 最大化或最小化泛函 F[y]F[y] 。变分法可以⽤来说明两点之间的最短路径是⼀条直线,或者最⼤熵分布是⾼斯分布。

四、变分推断

4.1 推断目标

上文已经定义了 xxzz ,这里再来明确一下,在 ML 领域,xx 表示观测变量(也叫输入变量、证据变量等,叫法无所谓),zz 表示隐变量(就是希望推断的 Label)。放在回归模型中,分别就是输入值和预测值;放在分类问题中,就是输入的图像等和输出等类别。

重新回顾一下我们的目标:

MinimizeDKL(q(z)p(zx))Minimize\quad D_{KL}(q(z)||p(z|x))

注意这里,我们使用了反向 KL 散度(即 DKL(qp)D_{KL}(q||p) ,用前面的分布 qq 近似后面的分布 pp ),KL 散度不具有对称性,这与一般的正向 KL 散度(即 DKL(pq)D_{KL}(p||q) ,用后面的分布 qq 近似前面的分布 pp )是不同的。我们为什么选择反向 KL 散度呢?

4.2 为什么选择反向 KL 散度?

对于正向 KL 散度,即 DKL(p(zx)q(z))=zp(zx)logp(zx)q(z)D_{KL}(p(z|x)||q(z))=\sum_z{p(z|x)\log{\frac{p(z|x)}{q(z)}}} 中,首先你会发现一个很大的问题就是,对数外面的后验分布 p(zx)p(z|x) 并没有办法像 4.3 节中一样化掉,无法计算!

问题不止如此,正向 KL 散度中,如果 p(zx)>0p(z|x)>0 ,如果 q(z)0q(z)\rightarrow 0 ,那么 DKL(pq)+D_{KL}(p||q)\rightarrow +\infty ,考虑如果 p(zx)p(z|x) 是一个双峰的分布,我们必须让 q(z)q(z) 同时逼近所有的峰,如果只贴近一个峰,那就会导致另一个峰的部分 q(z)0q(z)\rightarrow 0 ,KL 散度无穷大,这并不是我们所希望的。

那我们再来看反向 KL 散度解决了什么问题?首先,它是可计算的,在 4.3 节对 ELBO 的推导中天然出现;其次,反向 KL 散度有 “零强迫” 性质,即当 q(z)>0q(z) > 0 的地方,如果 p(zx)p(z|x) 很小,就会被强烈惩罚(因为 log0=\log 0 = -\infty),所以它会逼近 p(zx)p(z|x)高概率区域,面对多峰的分布,一般会收敛于主峰而忽略次要峰。

至于为什么倾向于选择单峰而非多峰,大概是因为其容易收敛,且一般不会出现去靠近多个峰而导致每个峰的近似效果都不太好。

4.3 边缘似然下界

p(zx)p(z|x) 并不容易求出,所以对反向 KL 散度进行变换:

DKL(q(z)p(zx))=zq(z)logq(z)p(zx)=Eq(z)[logq(z)p(zx)]=Eq(z)[logq(z)p(x)p(x,z)]=Eq(z)[logp(x)logp(x,z)q(z)]=logp(x)Eq(z)[logp(x,z)q(z)]\begin{align*} D_{KL}(q(z)\,\|\,p(z|x)) &= \sum_z q(z) \log \frac{q(z)}{p(z|x)} \\ &= \mathbb{E}_{q(z)} \left[ \log \frac{q(z)}{p(z|x)} \right] \\ &= \mathbb{E}_{q(z)} \left[ \log \frac{q(z)\cdot p(x)}{p(x,z)} \right] \\ &= \mathbb{E}_{q(z)} \left[ \log p(x) - \log \frac{p(x,z)}{q(z)} \right] \\ &= \log p(x) - \mathbb{E}_{q(z)} \left[ \log \frac{p(x,z)}{q(z)} \right] \end{align*}

由于输入变量的分布已知,即 logp(x)\log{p(x)} 为常量,最小化 KL 散度的目标转换为:

maxL(q)maxEq(z)[logp(x,z)q(z)]maxEq(z)[logp(x,z)logq(z)]\max{​\mathcal{L}(q)}\Leftrightarrow\max{\mathbb{E}_{q(z)} \left[ \log \frac{p(x,z)}{q(z)}\right]}\Leftrightarrow \max{\mathbb{E}_{q(z)}\left[\log{p(x,z)-\log{q(z)}}\right]}

L(q)\mathcal{L}(q) 又被称为边缘似然下界/证据下界(Evidence Lower Bound,ELBO),为什么称其为下界呢?

logp(x)=Eq(z)[logp(x,z)q(z)]ELBO+DKL(q(z)p(zx))\log p(x) = \underbrace{\mathbb{E}_{q(z)} \left[\log \frac{p(x, z)}{q(z)} \right]}_{\text{ELBO}} + D_{\text{KL}}(q(z) \| p(z|x)) 可见,由于 KL 散度的非负性,因此,logp(x)ELBO\log p(x) \geq \text{ELBO} ,即 ELBO\text{ELBO}logp(x)\log p(x) 的一个下界,我们优化它来逼近真实分布,

五、如何求解(极大化 ELBO )?

这是一个泛函最大化的问题(注意 L(q)\mathcal{L}(q) 是一个泛函)。在求解之前,还需要一些前置知识。

5.1 平均场理论

平均场理论(Mean Field Theory),按照数学上的说法,平均场的适用范围只能是完全图,在这种情况下,系统中的任何一个个体以等可能接触其他个体。简单来说,就是 “把复杂问题拆成简单问题,每个变量只看其他变量的平均影响” ,每个变量间的局部作用对于全局的影响是可以忽略不计的

打个比方,假如你住在一个有 10 个房间的 big house ,每个房间都有暖气,你们想让整个房子的温度最舒适。不过现实情况很复杂,因为每个房间调的温度不仅影响自己,还会影响旁边房间,平均场的思想就是每个房间都只关心“其他房间的大致平均温度”,然后自己调整一下,反复几轮。

这样做虽然并不是最完美的,但是简单,且最终会收敛到一个“大家都比较满意”的状态。

5.2 Mean Field VI 平均场变分推断

平均场变分推断把复杂的后验分布 q(z)q(z) 分解为若干个子变量的独立乘积:q(z)=iqi(zi)q(z) = \prod_{i} q_i(z_i) ,也称为平均场分布族,注意这里每一个 qjq_j 相互独立,所以我们分别对每一个 qjq_j 进行优化,固定其他变量,拆分出 qj(zj)q_j(z_j) 来:

L(qj)=Eq(z)[logp(x,z)logq(z)]=iqi[logp(X,Z)ilogqi]dZ=qj[logp(X,Z)ijqidZi]dZjqjlogqjdZj+const=Eqj[logp~(X,Zj)]Eqj[logqj(Zj)]+const\begin{align*} \mathcal{L}(q_j) &= \mathbb{E}_{q(z)}\left[\log{p(x,z)-\log{q(z)}}\right]\\ &= \int \prod_{i}q_i\left[\log p(X, Z)-\sum_i\log q_i\right]dZ \\ &= \int q_j \left[\int\log p(X,Z)\prod_{i\ne j}q_i dZ_i\right]dZ_j - \int q_j\log q_j dZ_j + const\\ &= \mathbb{E}{q_j}[\log \tilde{p}(X, Z_j)] - \mathbb{E}{q_j}[\log q_j(Z_j)] + \text{const} \end{align*}

其中:

  • iqi=qjijqi\prod_i q_i = q_j \cdot \prod_{i \ne j} q_i
  • dZ=dZjdZjdZ = dZ_j \cdot dZ_{\ne j}
  • 其余的项都视为常数项,用 constconst 表示;
  • logp~(X,Zj):=logp(X,Z)ijqidZi\log \tilde{p}(X, Z_j) := \int \log p(X, Z) \prod_{i \ne j} q_i\, dZ_i

可以发现,常数以外的部分就是 DKL(qj(Zj)p~(X,Zj))-D_{KL}(q_j(Z_j)||\tilde{p}(X, Z_j)) ,最大化 L(q)\mathcal{L}(q) 的问题转换为最小化这个 KL 散度,亦知其在 qj(Zj)=p~(X,Zj)q_j(Z_j)=\tilde{p}(X, Z_j) 时取得最小值 0 ,所以,最优解为:

logqj(Zj)=Eij[logp(X,Z)]+const\log q_j^*(Z_j) = \mathbb{E}_{i \ne j} [\log p(X, Z)] + \text{const}

依次更新所有 qjq_j ,最终达到稳定。

六、变分推断的应用

有关变分推断的应用,包括变分自编码器等,等之后有空再继续填吧(也可能填到新的文章里)。总之关于变分推断的核心思想和目标,或者其类似的形式,将伴随你在传统机器学习、深度学习以及强化学习的各种 loss function 中。

后记

打公式太累了…另外笔者主要是按个人理解顺下来写的,限于个人水平,有打错或者有理解上的错误在所难免,欢迎大佬们在评论区指出!


【PRML】如何简单易懂地理解变分推断(Variational Inference)
https://blog.yokumi.cn/2025/07/02/【PRML】如何简单易懂地理解变分推断(Variational Inference)/
作者
Yokumi
发布于
2025年7月2日
更新于
2025年7月6日
许可协议