写在前面
照搬了一部分去年《概率论与数理统计》课程自己的期中论文,当时啃了大名鼎鼎的经典著作 《模式识别与机器学习》(Pattern Recognition and Machine Learning,PRML) 大概三四章,没按顺序,并且啃地也较快,后面也没时间填坑。现在继续来填概率论的大坑,不过应该也不会特别细,楼主太菜了,能理解就理解,至于一些需要记忆的一般偏向于用到了现查。
一、为什么需要变分推断?
先来简单地回顾一下贝叶斯推断,即已知 x ,我们希望知道潜在变量 z 的概率分布,即所谓的后验概率分布 p(z∣x) :
p(z∣x)=p(x)p(x∣z)p(z)=∫p(x,z)dzp(x,z)
其中:
- x :观察数据 ;
- z :潜在变量(latent variable);
- p(x,z)=p(x∣z)p(z) :联合分布;
- p(z∣x) :后验分布;
- p(x)=∫p(x,z)dz :边缘似然;
实际情况中,往往由于变量维度较高,或者被积函数过于复杂,导致精确计算边缘似然的代价过高。这时候就需要进行近似推断,典型代表之一就是变分推断。
二、变分推断的核心思想
用一个可调的、简化的分布 q(z) 去逼近真实后验 p(z∣x),使它们越像越好。计算后验概率的问题转换为一个最优化问题,就是让 q(z) 尽可能接近 p(z∣x) ,而衡量两个分布之间的差异性,就要用到 KL 散度了。
三、何为变分?
变分听起来是一个很高深的名词,实则也确实一点都不简单(对笔者而言)。
我们可以把函数 y(x) 看成⼀个运算符。对于任意输⼊ x ,这个运算符都能返回⼀个输出 y ,或者说 y 是 关于 x 的一个实数算子,将实数 x 映射到实数 y(x) 。
类比一下,可以定义一个泛函 F[y] ,它是一个运算符,它以 y(x) 作为输入,输出 F 。这样讲可能比较抽象,所以举几个例子:
- 泛函表示二维空间中一条曲线的长度,曲线的轨迹根据函数来定义;
- 在 ML 领域,常见的就是随机变量 x 的熵 H(x) ;
两个点之间有无数条路径,每一条路径都是一个函数,当函数变化了一点而导致泛函值变化了多少,这就是变分。传统微积分中,希望找到一个 x 使得 y(x) 取到最大值或最小值。同样,变分法中,希望找到一个函数 y(x) 最大化或最小化泛函 F[y] 。变分法可以⽤来说明两点之间的最短路径是⼀条直线,或者最⼤熵分布是⾼斯分布。
四、变分推断
4.1 推断目标
上文已经定义了 x 和 z ,这里再来明确一下,在 ML 领域,x 表示观测变量(也叫输入变量、证据变量等,叫法无所谓),z 表示隐变量(就是希望推断的 Label)。放在回归模型中,分别就是输入值和预测值;放在分类问题中,就是输入的图像等和输出等类别。
重新回顾一下我们的目标:
MinimizeDKL(q(z)∣∣p(z∣x))
注意这里,我们使用了反向 KL 散度(即 DKL(q∣∣p) ,用前面的分布 q 近似后面的分布 p ),KL 散度不具有对称性,这与一般的正向 KL 散度(即 DKL(p∣∣q) ,用后面的分布 q 近似前面的分布 p )是不同的。我们为什么选择反向 KL 散度呢?
4.2 为什么选择反向 KL 散度?
对于正向 KL 散度,即 DKL(p(z∣x)∣∣q(z))=∑zp(z∣x)logq(z)p(z∣x) 中,首先你会发现一个很大的问题就是,对数外面的后验分布 p(z∣x) 并没有办法像 4.3 节中一样化掉,无法计算!
问题不止如此,正向 KL 散度中,如果 p(z∣x)>0 ,如果 q(z)→0 ,那么 DKL(p∣∣q)→+∞ ,考虑如果 p(z∣x) 是一个双峰的分布,我们必须让 q(z) 同时逼近所有的峰,如果只贴近一个峰,那就会导致另一个峰的部分 q(z)→0 ,KL 散度无穷大,这并不是我们所希望的。
那我们再来看反向 KL 散度解决了什么问题?首先,它是可计算的,在 4.3 节对 ELBO 的推导中天然出现;其次,反向 KL 散度有 “零强迫” 性质,即当 q(z)>0 的地方,如果 p(z∣x) 很小,就会被强烈惩罚(因为 log0=−∞),所以它会逼近 p(z∣x) 的高概率区域,面对多峰的分布,一般会收敛于主峰而忽略次要峰。
至于为什么倾向于选择单峰而非多峰,大概是因为其容易收敛,且一般不会出现去靠近多个峰而导致每个峰的近似效果都不太好。
4.3 边缘似然下界
p(z∣x) 并不容易求出,所以对反向 KL 散度进行变换:
DKL(q(z)∥p(z∣x))=z∑q(z)logp(z∣x)q(z)=Eq(z)[logp(z∣x)q(z)]=Eq(z)[logp(x,z)q(z)⋅p(x)]=Eq(z)[logp(x)−logq(z)p(x,z)]=logp(x)−Eq(z)[logq(z)p(x,z)]
由于输入变量的分布已知,即 logp(x) 为常量,最小化 KL 散度的目标转换为:
maxL(q)⇔maxEq(z)[logq(z)p(x,z)]⇔maxEq(z)[logp(x,z)−logq(z)]
L(q) 又被称为边缘似然下界/证据下界(Evidence Lower Bound,ELBO),为什么称其为下界呢?
由 logp(x)=ELBOEq(z)[logq(z)p(x,z)]+DKL(q(z)∥p(z∣x)) 可见,由于 KL 散度的非负性,因此,logp(x)≥ELBO ,即 ELBO 是 logp(x) 的一个下界,我们优化它来逼近真实分布,
五、如何求解(极大化 ELBO )?
这是一个泛函最大化的问题(注意 L(q) 是一个泛函)。在求解之前,还需要一些前置知识。
5.1 平均场理论
平均场理论(Mean Field Theory),按照数学上的说法,平均场的适用范围只能是完全图,在这种情况下,系统中的任何一个个体以等可能接触其他个体。简单来说,就是 “把复杂问题拆成简单问题,每个变量只看其他变量的平均影响” ,每个变量间的局部作用对于全局的影响是可以忽略不计的。
打个比方,假如你住在一个有 10 个房间的 big house ,每个房间都有暖气,你们想让整个房子的温度最舒适。不过现实情况很复杂,因为每个房间调的温度不仅影响自己,还会影响旁边房间,平均场的思想就是每个房间都只关心“其他房间的大致平均温度”,然后自己调整一下,反复几轮。
这样做虽然并不是最完美的,但是简单,且最终会收敛到一个“大家都比较满意”的状态。
5.2 Mean Field VI 平均场变分推断
平均场变分推断把复杂的后验分布 q(z) 分解为若干个子变量的独立乘积:q(z)=∏iqi(zi) ,也称为平均场分布族,注意这里每一个 qj 相互独立,所以我们分别对每一个 qj 进行优化,固定其他变量,拆分出 qj(zj) 来:
L(qj)=Eq(z)[logp(x,z)−logq(z)]=∫i∏qi[logp(X,Z)−i∑logqi]dZ=∫qj∫logp(X,Z)i=j∏qidZidZj−∫qjlogqjdZj+const=Eqj[logp~(X,Zj)]−Eqj[logqj(Zj)]+const
其中:
- ∏iqi=qj⋅∏i=jqi ;
- dZ=dZj⋅dZ=j ;
- 其余的项都视为常数项,用 const 表示;
- logp~(X,Zj):=∫logp(X,Z)∏i=jqidZi
可以发现,常数以外的部分就是 −DKL(qj(Zj)∣∣p~(X,Zj)) ,最大化 L(q) 的问题转换为最小化这个 KL 散度,亦知其在 qj(Zj)=p~(X,Zj) 时取得最小值 0 ,所以,最优解为:
logqj∗(Zj)=Ei=j[logp(X,Z)]+const
依次更新所有 qj ,最终达到稳定。
六、变分推断的应用
有关变分推断的应用,包括变分自编码器等,等之后有空再继续填吧(也可能填到新的文章里)。总之关于变分推断的核心思想和目标,或者其类似的形式,将伴随你在传统机器学习、深度学习以及强化学习的各种 loss function 中。
后记
打公式太累了…另外笔者主要是按个人理解顺下来写的,限于个人水平,有打错或者有理解上的错误在所难免,欢迎大佬们在评论区指出!