文章从连续情形出发开始介绍重参数,主要的例子是正态分布的重参数;然后引入离散分布的重参数,这就涉及到了Gumbel Softmax,包括Gumbel Softmax的一些证明和讨论;最后再讲讲重参数背后的一些故事,这主要跟梯度估计有关。
基本概念
重参数(Reparameterization)实际上是处理如下期望形式的目标函数的一种技巧:
这样的目标在VAE中会出现,在文本GAN也会出现,在强化学习中也会出现(\(f(z)\) 对应于奖励函数),所以深究下去,我们会经常碰到这样的目标函数。取决于 \(z\) 的连续性,它对应不同的形式:
当然,离散情况下我们更喜欢将记号 \(z\) 换成 \(y\) 或者 \(c\)。
为了最小化\(L_{\theta}\),我们就需要把 \(L_{\theta}\) 明确地写出来,这意味着我们要实现从 \(p_{\theta}(z)\) 中采样,而 \(p_{\theta}(z)\) 是带有参数 \(\theta\) 的,如果直接采样的话,那么就失去了 \(\theta\) 的信息(梯度),从而无法更新参数 \(\theta\)。而Reparameterization则是提供了这样的一种变换,使得我们可以直接从 \(p_{\theta}(z)\) 中采样,并且保留 \(\theta\) 的梯度。(注:如果考虑最一般的形式,那么\(f(z)\)也应该带上参数\(\theta\),但这没有增加本质难度。)
连续情形
简单起见,我们先考虑连续情形
其中 \(p_{\theta}(z)\) 是具有显式概率密度表达式的分布,在变分自编码器中常见的是正态分布 \(p_{\theta}(z)=\mathcal{N}\left(z;\mu_{\theta},\sigma_{\theta}^2\right)\)。
形式
从 式3 中知道,连续情形的 \(L_{\theta}\) 实际上就对应一个积分,所以,为了明确写出 \(L_{\theta}\),有两种途径:最直接的方式是精确地完成积分式3 ,得到显式表达式,但这通常都是不可能的了;所以,唯一的办法是转化为采样形式 式1,并试图在采样过程中保留 \(\theta\) 的梯度。
重参数就是这样的一种技巧,它假设从分布 \(p_{\theta}(z)\) 中采样可以分解为两个步骤:
- 从无参数分布 \(q(\varepsilon)\) 中采样一个\(\varepsilon\);
- 通过变换 \(z=g_{\theta}(\varepsilon)\) 生成 \(z\)。那么,式1 就变成了
$$
\begin{equation}L_{\theta}=\mathbb{E}{\varepsilon\sim q(\varepsilon)}[f(g{\theta}(\varepsilon))]\end{equation}
$$
这时候被采样的分布就没有任何参数了,全部被转移到 \(f\) 内部了,因此可以采样若干个点,当成普通的loss那样写下来了。
例子
一个最简单的例子就是正态分布:对于正态分布来说,重参数就是“从\(\mathcal{N}\left(z;\mu_{\theta},\sigma_{\theta}^2\right)\) 中采样一个 \(z\) ”变成“从 \(\mathcal{N}\left(\varepsilon;0, 1\right)\) 中采样一个 \(\varepsilon\),然后计算\(\varepsilon\times \sigma_{\theta} + \mu_{\theta}\)”,所以
如何理解直接采样没有梯度而重参数之后就有梯度呢?其实很简单,比如我说从\(\mathcal{N}\left(z;\mu_{\theta},\sigma_{\theta}^2\right)\) 中采样一个数来,然后你跟我说采样到5,我完全看不出5跟 \(\theta\) 有什么关系呀(求梯度只能为0);但是如果先从 \(\mathcal{N}\left(\varepsilon;0, 1\right)\) 中采样一个数比如0.2,然后计算 \(0.2 \sigma_{\theta} + \mu_{\theta}\),这样我就知道采样出来的结果跟\(\theta\) 的关系了(能求出有效的梯度)。
总结
让我们把前面的内容重新整理一下。总的来说,连续情形的重参数还是比较简单的:连续情形下,我们要处理的 \(L_{\theta}\) 实际上是 式3 ,由于精确的积分我们没有办法显式地写出来,所以需要转化为采样,而为了在采样的过程中得到有效的梯度,我们就需要重参数。
从数学本质来看,重参数是一种积分变换,即原来是关于 $z $ 积分,通过 \(z=g_{\theta}(\varepsilon)\) 变换之后得到新的积分形式,
离散情形
为了突出“离散”,我们将随机变量 \(z\) 换成 \(y\),即对于离散情形要面对的目标函数是
其中离散意味着一般情况 \(y\) 是可枚举的,换句话说 \(p_{\theta}(y)\) 此时是一个 \(k\) 分类模型:
其中各个 \(o_i\) 是 \(\theta\) 的函数。
分析
读者看到 式6 中的求和,第一反应可能是“求和?那就求呗,又不是求不了”。
的确,这也是笔者当时看到它的第一反应。与连续情形的 式3 不一样,式3 如果直接硬杠的话需要完成积分(也可以看成无穷多个点的求和),我们没法做到这一点。但是对于离散的 式6 ,只不过是有限项求和,理论上确实可以直接完成求和再去梯度下降。
但是,如果 \(k\) 特别大呢?举个例子,假设 \(y\) 是一个100维的向量,每个元素不是 0 就是 1(二元变量),那么所有不同的y的总数目就是\(2^{100}\),要对这样的 \(2^{100}\) 个单项进行求和,计算量是难以接受的;还有一个典型的例子是seq2seq的解码端(如果要做文本GAN就需要面对它),它的类别总数目是\(|V|^l\),其中\(|V|\)是词表大小而 \(l\) 是句子长度。这样的情况下,直接完成精确的求和都是难以实现的。
形式
所以,还是需要回到采样上去,如果能够采样若干个点就能得到 式6 的有效估计,并且还不损失梯度信息,那自然是最好了。为此,需要先引入Gumbel Max,它提供了一种从类别分布中采样的方法。
假设每个类别的概率是\(p_1,p_2,\dots,p_k\),那么下述过程提供了一种依概率采样类别的方案,称为Gumbel Max:
也就是说,先算出各个概率的对数 \(\log p_i\),然后从均匀分布\(U[0,1]\) 中采样 \(k\) 个随机数 \(\varepsilon_1,\dots,\varepsilon_k\),把 \(-\log(-\log \varepsilon_i)\) 加到 \(\log p_i\) 上去,最后把最大值对应的类别抽取出来就行了。
后面我们会证明,这样的过程精确等价于依概率 \(p_1,p_2,\dots,p_k\) 采样一个类别,换句话说,在Gumbel Max 中,输出 \(i\) 的概率正好是 \(p_i\)。由于现在的随机性已经转移到 \(U[0,1]\)上去了,并且\(U[0,1]\) 不带有未知参数,因此Gumbel Max就是离散分布的一个重参数过程。
但是,我们希望重参数不丢失梯度信息,但是Gumbel Max做不到,因为 \(\mathop{\text{argmax}}\) 不可导,为此,需要做进一步的近似。首先,留意到在神经网络中,处理离散输入的基本方法是转化为one hot形式,包括Embedding层的本质也是one hot全连接(参考词向量与Embedding),因此 \(\mathop{\text{argmax}}\) 实际上是 \(\text{onehot}(\mathop{\text{argmax}}))\),然后,我们寻求\(\text{onehot}(\mathop{\text{argmax}}))\)的光滑近似,它就是softmax(参考函数光滑化:不可导函数的可导逼近)。
由此,我们得到Gumbel Max的光滑近似版本——Gumbel Softmax:
其中参数 \(\tau > 0\) 称为退火参数,它越小输出结果就越接近one hot形式(但同时梯度消失就越严重)。提示一个小技巧,如果 \(p_i\) 是softmax的输出,即 式7 的形式,那么大可不必先算出 \(p_i\) 再取对数,直接将 \(\log p_i\) 替换为 \(o_i\) 即可:
💡 *Gumbel Max的证明:*
例子
跟连续情形一样,Gumbel Softmax就是用在需要求 \(\mathbb{E}_{y\sim p_{\theta}(y)}[f(y)]\)、且无法直接完成对 \(y\) 求和的场景,这时候我们算出 \(p_{\theta}(y)\)(或者\(o_i\)),然后选定一个\(\tau > 0\),用Gumbel Softmax算出一个随机向量来 \(\tilde{y}\),代入计算得到 \(f(\tilde{y})\),它就是 \(\mathbb{E}_{y\sim p_{\theta}(y)}[f(y)]\) 的一个好的近似,且保留了梯度信息。
注意,Gumbel Softmax不是类别采样的等价形式,Gumbel Max才是。而Gumbel Max可以看成是Gumbel Softmax在 \(\tau \to 0\) 时的极限。所以在应用Gumbel Softmax时,开始可以选择较大的 \(\tau\)(比如1),然后慢慢退火到一个接近于0的数(比如0.01),这样才能得到比较好的结果。
下面提供一个自己实现的离散隐变量的VAE例子:
https://github.com/bojone/vae/blob/master/vae_keras_cnn_gs.py
效果图:

溯源
Gumbel Max由来已久,但首次提出并应用Gumbel Softmax的是论文《Categorical Reparameterization with Gumbel-Softmax》,这篇论文主要探讨了部分隐变量是离散型变量的变分推断问题,比如基于VAE的半监督学习。其后,在文章《GANS for Sequences of Discrete Elements with the Gumbel-softmax Distribution》中,Gumbel Softmax首次被用在离散序列生成,但还不是文本生成,而是比较简单的人造字符序列。
其后,SeqGAN被提出,自那以后文本GAN模型一直以与强化学习结合的方式出现,基于Gumbel Softmax的纯深度学习和梯度下降的方法相对沉寂,直到RelGAN的出现。RelGAN是ICLR 2019提出的模型,它提出了新型的生成器和判别器结构,使得直接用Gumbel Softmax训练出的文本GAN大幅度超过了以往的各种文本GAN模型。关于RelGAN,我们后面有机会再谈。
总结
这部分内容主要介绍的是Gumbel Softmax,它是离散情形下 式1 型损失的一个重参数技巧。
理论上来说,离散情形的 式1 只是有限项求和,不一定需要重参数。但事实上,“有限”也可能是相当大的数字,因此遍历求和可能难以进行,所以还是要转化为采样形式,从而需要重参数技巧,这就是Gumbel Softmax,源于对Gumbel Max的光滑化。
除了上述视角外,还有一个辅助的视角:Gumbel Softmax通过\(\tau\to 0\)的退火来逐渐逼近one hot,相比直接用原始的Softmax进行退火,区别在于原始Softmax退火只能得到最大值位置为1的one hot向量,而Gumbel Softmax有概率得到非最大值位置的one hot向量,增加了随机性,会使得基于采样的训练更充分一些。
背后的故事
重参数就这样介绍完了吗?远远没有,重参数的背后,实际上是一个称为“梯度估计(gradient estimator)”的大家族,而重参数只不过是这个大家族中的一员。每年的ICLR、ICML等顶会上搜索gradient estimator、REINFORCE等关键词,可以搜索到不少文章,说明这是个大家还在钻研的课题。
要想说清重参数的来龙去脉,也要说些梯度估计的故事。
SF估计
前面我们分别讲了连续型和离散型的重参数,都是在“loss层面”讲述的,也就是说都是想办法把loss显式地定义好,剩下的交给框架自动求导、自动优化就是了。而事实上,就算不能显式地写出loss函数,也不妨碍我们对它求导,自然也不妨碍我们去用梯度下降了。比如
现在我们得到了梯度的一个估计式,称为“SF估计”,全称是Score Function Estimator,这是对原来损失函数的最朴素的估计,在强化学习中z代表着策略,那么上式就是一个最基本的策略梯度,所以有时候也直接称上述估计为叫REINFORCE。要注意,对离散情形的损失函数重新推导一遍,结果也是一样的,也就是说,上述结果是通用的,不区分z是连续变量还是离散变量。现在我们可以直接从\(p_{\theta}(z)\) 中采样若干个点来估算式 式15 的值了,不用担心会不会没梯度,因为式\eqref{eq:sf}本身就是梯度了。
梯度方差
看上去很美好,得到了一个连续和离散变量都适用的估计式,那为什么还需要重参数呢?
主要的原因是:SF估计的方差太大。 式15 是函数\(f(z) \frac{\partial}{\partial\theta} \log p_{\theta}(z)\) 在分布 \(p_{\theta}(z)\)下的期望,我们要采样几个点来算(理想情况下,希望只采样一个点),换句话说,我们想用下面的近似
于是问题就来了:这样的梯度估计方差很大。
什么是方差很大?它有什么影响?举个简单的例子,假如\(\alpha = avg([4, 5, 6]) = avg([0, 5, 10])\),也就是说,我们的目标 \(\alpha\) 是三个数的平均值,这三个数要不就是4,5,6,要不就是0,5,10,在精确估计的情况下,两者是等价的,但是如果每一组只能随机选其中一个数呢?第一组可能选到4,这也没什么,跟准确值5只差一点;但是第二组可能选到0,这跟准确值5差得就有点大了。也就是说,随机选一个的情况下,第二组估计的波动(方差)太大了。类似地,SF估计出来的梯度方差也是如此,这导致了我们用梯度下降优化的时候相当不稳定,非常容易崩。
降方差
从形式上看,式\eqref{eq:sf}是非常漂亮的,本身形式不复杂,而且对离散变量和连续变量都通用,还对f没有特别要求(相反,重参数要求f可导,但是在诸如强化学习的场景下,f(z)对应着奖励函数,很难做到光滑可导)。所以,很多文章探讨基于 式15 的降方差技巧,论文《Categorical Reparameterization with Gumbel-Softmax》就列举了一些,近几年来也有一些新发展,总之,还是那句话,大家搜索gradient estimator、REINFORCE等关键词,就有不少文章了。
重参数是另一种降方差技巧,为此,我们写出重参数后的\eqref{eq:reparam}的梯度表达式:
对比SF估计的 式15 ,我们可以直观感知为什么上式方差更小了:
- SF估计中包含了\(\log p_{\theta}(z)\),我们知道,作为一个合理的概率分布,一般都在无穷远处(即\(\Vert z\Vert \to \infty\))都会有\(p_{\theta}(z)\to 0\),取了 \(\log\) 之后反而会趋于负无穷,换句话说,\(\log p_{\theta}(z)\)这一项实际上放大了无穷远处的波动,从而一定程度上增加了方差;
- SF估计中包含的是f而重参数之后变成了\(\frac{\partial f}{\partial g}\),\(f\) 一般是神经网络,而通常我们定义的神经网络模型其实都是\(\mathcal{O}(z)\)级别的模型,从而我们可以预期它的梯度是\(\mathcal{O}(1)\)级别的(不严格成立,只能说在平均意义下基本成立),所以相对情况下更平稳一些,因此f的方差也比\(\frac{\partial f}{\partial g}\) 的方差要大。
鉴于这两个理由,我们就可以得出,一般情况下重参数之后梯度估计的方差会比SF估计要小。注意,这里还是要强调“一般情况”,换言之,“重参数降低梯度估计的方差”这个结论不是绝对成立的,上述两个理由都是在一般情况下(我们面对的多数模型)成立,如果非要较劲,我们总能构造出重参数反而增加方差的例子。
文章小结
经过一番长篇大论,我们总算把重参数的故事基本上都捋清楚了。更深入地理解重参数技巧,是更好地理解VAE及文本GAN的必经之路。
从loss层面看,我们需要分连续和离散两种情形:连续情形下,重参数是用采样形式且不损失梯度地写出loss的方法;离散情形下,重参数有着跟连续情形一样的作用,不过更根本的原因是降低计算量(否则直接遍历求和也行)。从梯度估计层面看,重参数是降低梯度估计方差的一种有效手段,而同时还有其他的降低方差手段也被不少学者研究中。