EDM

Dec 20, 2024
1 views
Generative Model

基于文章《Elucidating the Design Space of Diffusion-Based Generative Models》来统一扩散模型框架

💡 统一扩散模型框架:

通用扩散模型框架推导

加噪公式

  • Flow Matching的一步加噪公式
  • Score Matching的一步加噪公式
  • DDPM/DDIM的一步加噪公式
    其中, \(\mathbf{x}_0\) 都是原始图像, \(\sigma\sim\mathcal{N}(\mathbf{0}, \mathcal{I})\)

通用加噪公式形式探索

发现这三者存在一定的规律,写成一个通用形式:

那这个形式一定对应一个随机微分方程,这个方程的解可以描述 \(x_t\) 分布的变化:

根据DDPM和SMLD的推导结果,实际上 我们通常将漂移项简化为线性形式:

\[ \begin{equation} f(x_t,t) \rightarrow f(t)x_t \quad f(t):\mathbb{R}^1\rightarrow\mathbb{R}^1 \end{equation} \]

简化后的SDE形式为:

$$
\begin{equation}dx_t = f(t)x_tdt + g(t)dw\end{equation}

$$

均值

SDE均值定义:\(m(t) = E[x_t]\), 对应的均值的微分:

\[ \begin{equation}\frac{dm}{dt}=E[f(t,x_t)] \end{equation} \]

推导步骤:
对原始SDE两边取期望:

$$
E[dx_t] = E[f(t)x_t]dt = f(t)E[x_t]dt

$$

根据均值公式 \(dm= E[f(t,x_t)]dt\) 可以推导:

代入均值定义:

\[ dm(t) = E[f(t)x_t]dt = f(t)E[x_t]dt=f(t)m(t)dt \]

两边积分:

$$
\int \frac{1}{m}dm = \int_0^t f(r)dr + C

$$

求解得到:

$$
\ln|m| = \int_0^t f(r)dr + C

$$

指数化:

\[ m = e^{\int_0^t f(r)dr + C}=e^{\int_0^t f(r)dr} e^C=Ae^{\int_0^t f(r)dr} \]

代入初始条件 \(m(0) = x_0\)(因为\(t=0\)时对应的是原图,原图的均值就是\(x_0\))最终得到解:

$$
\begin{equation}m(t) = e^{\int_0^t f(r)dr}x_0\end{equation}

$$

所以在(1)式通用形式中的\(s(t)\)对应为:

$$
\begin{equation}s(t) = \exp{\int_0^t f(r)dr}\end{equation}

$$

协方差

SDE的协方差矩阵定义:$P(t) = E[(x_t - m)(x_t - m)^T]
$, 对应的协方差矩阵的微分:

\[ \begin{equation}\frac{dP}{dt} = E[f(x_t,t)(x_t - m)^T] + E[(x_t - m)f(x_t,t)^T] + E[g^2(t)]\end{equation} \]

上面式子的推导用了Itô公式(因为涉及随机过程),步骤略。

根据式3, 可将上面式8化简为:

\[ \frac{dP}{dt} = f(t)E[x_tx_t^T - x_tm^T] + f(t)E[x_tx_t^T - mx_t^T] + g^2(t) \]

再次考虑分离变量法,从\(P\)的定义出发, 可以得到:

$$
P = E[(x_t - m)(x_t - m)^T] = E[x_tx_t^T - mx_t^T - x_tm^T + mm^T]

$$

注意到:

\[ E[mm^T] = mm^T = mE[x_t^T] = E[mx_t^T] = E[x_t]m^T = E[x_tm^T] \]

所以有:

\[ E[x_tx_t^T - x_tm^T + x_tx_t^T - mx_t^T] = E[2x_tx_t^T - 2mm^T] = 2E[x_tx_t^T - mm^T] = 2P \]

这样式8就可以接着化简得到:

\[ \frac{dP}{dt} = 2f(t)P + g^2(t) \]

目前好像没办法接着化简了, 因为出现了\(f(t)P\),但这刚好又是另外一种方程,名叫一阶非齐次线性ODE,其标准形式为:

\[ \frac{dy}{dx} + G(x)y(x) = Q(x) \]

其对应关系为:

\[ y(x) \Rightarrow P(t), G(x) \Rightarrow -2f(t), Q(x) \Rightarrow g^2(t) \]

是有通解的,直接给出通解形式:

\[ y(x)=e^{-\int G(x)\mathrm{d}x}\int Q(x)e^{\int G(x)\mathrm{d}x}\mathrm{d}x+Ce^{-\int G(x)\mathrm{d}x} \]

代入上面式子中的信息,可得:

\[ \mathbf{p}(t)=e^{\int_0^t2f(r)\mathrm{d}r}\int_0^tg^2(r)e^{\int_0^t-2f(r)\mathrm{d}r}\mathrm{d}r+Ce^{\int_0^t2f(r)\mathrm{d}r} \]

代入初始条件\(\mathbf{P}(0)=0\)\(t=0\)时刻原图的方差为0),有:

\[ \begin{aligned}\mathbf{P}(0)&=e^{\int_0^02f(r)\mathrm{d}r}\int_0^0g^2(r)e^{\int_0^0-2f(r)\mathrm{d}r}\mathrm{d}r+Ce^{\int_0^02f(r)\mathrm{d}r}\\&=e^0*0+Ce^0\\&=C\end{aligned} \]

得到\(C=0\),因此有:

$$
\begin{equation}\mathbf{p}(t)=e^{\int_0^t2f(r)\mathrm{d}r}\int_0^tg^2(r)e^{\int_0^t-2f(r)\mathrm{d}r}\mathrm{d}r\end{equation}

$$

又已知\(s(t)=e^{\int_0^tf(r)\mathrm{d}r}\),有:

\[ \begin{equation}s^2( t) = e^{\int _0^t2f( r) \mathrm{d} r}\end{equation} \]

对应的 \(\frac 1{s^2( t) }= e^{\int _0^t- 2f( r) \mathrm{d} r}\), 又由加噪公式通用形式, 可得:

\[ \begin{equation}\sigma^2(t)=\int_0^t\frac{g^2(r)}{s^2(r)}\mathrm{d}r \end{equation} \]

小结

💡 至此, 我们来稍微总结一下,加噪公式就是将一个分布变为另外一个分布的桥梁,实际上也就是大家耳熟能详的流(Flow)。加噪公式的存在是为训练过程服务的,只有建立这个桥梁,才能明确模型的训练目标,扩散模型的学习才有可能。 EDM给定的通用加噪公式为:

扩散模型通用概率流常微分方程

有了前向随机微分⽅程的形式:

通过福克普朗克⽅程(Fokker–Planck Equation)的推导,可以得到SDE对应的概率流常微分⽅程(Probability Flow Ordinary Differential Equation,PFODE)。这个PFODE在确定起点\(\mathbf{x}_0\)(前向)或\(\mathbf{x}_N\)(逆向)的前提下,解的分布(也即\(p(\mathbf{x}_t)\)\(\mathbf{x}_t\)的边缘概率密度)与加噪过程SDE求得的解的分布是完全相同的。这个PFODE的形式为:

\[ \begin{equation}\begin{aligned}\mathrm{d}\mathbf{x}&=\Big[\mathbf{f}(\mathbf{x},t)-\frac{1}{2}g(t)^{2}\nabla_{\mathbf{x}}\log p_{t}(\mathbf{x})\Big]\mathrm{d}t\\&=\Big[\mathbf{f}(t)x_t-\frac{1}{2}g(t)^{2}\nabla_{\mathbf{x}}\log p_{t}(\mathbf{x})\Big]\mathrm{d}t\end{aligned}\end{equation} \]

推导的过程见这里:伊藤公式与Fokker-Planck方程

推导不含f(t)和g(t)的PFODE表达式

现在我们的⽬标就是:摆脱复杂理论束缚,也不搞那些随机微分⽅程,就单纯通过设计加噪公式,直接写出对应的PFODE,然后直接⽤ODE求解器采样,进⽽获得⽣成图像

假设我设计的⼀步加噪公式是下⾯这个通⽤形式

首先根据\(s(t)\)\(f(t)\)的关系,也就是上面7式, 进一步可得到:

\[ \ln s(t)=\int_0^tf(r)\mathrm{d}r \]

两边求导,积分上限函数 \(\int_0^tf(r)\mathrm{d}r\)\(f(t)\)的一个原函数

\[ \begin{aligned}&\frac1{s(t)}\dot{s}(t)=f(t)\end{aligned} \]

即:

\[ \begin{equation}f(t)=\frac{\dot{s}(t)}{s(t)}\end{equation} \]

再根据\(\sigma(t)\)\(f(t)\)\(g(t)\)的关系, 也就是上面11式, 两边求导,积分上限函数是 \(\frac{g^2(t)}{s^2(t)}\) 的一个原函数

\[ \begin{aligned}&2\sigma(t)\dot{\sigma}(t)=\frac{g^2(t)}{s^2(t)}\\&g(t)=\sqrt{2\sigma(t)\dot{\sigma}(t)s^2(t)}\end{aligned} \]

即:

\[ \begin{equation}g(t)=s(t)\sqrt{2\sigma(t)\dot{\sigma}(t)}\end{equation} \]

至此,\(f(t)\)\(g(t)\) 可以完全由 \(s(t)\)\(\sigma(t)\) 表示,反过来也可以表示。这样看似代入公式(13)就能解决问题了,但实际上还差一步

虽然搞定了\(f(t)\)\(g(t)\),但是 \(p_t(\mathbf{x}_t)\) 是未知的,换句话说这种边缘分布如果能够知道,直接从里面采样即可,完全没必要这么复杂通过迭代的方式逐步获得结果。所以,还要对 \(\nabla\mathbf{x}_t\log p_t(\mathbf{x}_t)\) 进行分析,首先考虑边缘概率密度 \(p_t(\mathbf{x}_t)\)

\[ \begin{equation}\begin{aligned}p_{t}(\mathbf{x}_{t})&=\int_{\mathbb{R}^d}p_{\mathrm{data}}(\mathbf{x}_0)p_{0t}(\mathbf{x}_t|\mathbf{x}_0)\mathrm{d}\mathbf{x}_0\quad \text{全概率密度公式} \\&=\int_{\mathbb{R}_d}p_{\mathrm{data}}(\mathbf{x}_0)\left[\mathcal{N}\left(\mathbf{x}_t;s(t)\mathbf{x}_0,s^2(t)\sigma^2(t)\mathbf{I}\right)\right]\mathrm{d}\mathbf{x}_0 \\&=\int_{\mathbb{R}_{d}}p_{\mathrm{data}}(\mathbf{x}_{0})\left[\underbrace{s^{-d}(t)}_{\text{保证概率密度函数积分为}1}\mathcal{N}\left(\frac{\mathbf{x}_{t}}{s(t)};\mathbf{x}_{0},\sigma^{2}(t)\mathbf{I}\right)\right]\mathrm{d}\mathbf{x}_{0} \\&=s^{-d}(t)\int_{\mathbb{R}d}p_{\mathrm{data}}(\mathbf{x}_0)\mathcal{N}\left(\frac{\mathbf{x}_t}{s(t)};\mathbf{x}_0,\sigma^2(t)\mathbf{I}\right)\mathrm{d}\mathbf{x}_0 \\&=s^{-d}(t)\int_{\mathbb{R}_d}p_{\mathrm{data}}(\mathbf{x}_0)\mathcal{N}\left(\frac{\mathbf{x}_t}{s(t)}-\mathbf{x}_0;0,\sigma^2(t)\mathbf{I}\right)\mathrm{d}\mathbf{x}_0\quad\text{不影响概率} \\&=s^{-d}(t)\underbrace{\left[p_{\mathrm{data}}*\mathcal{N}\left(0,\sigma^2(t)\mathbf{I}\right)\right]}_\text{分布卷积运算}{\left(\frac{\mathbf{x}_t}{s(t)}\right)}\end{aligned}\end{equation} \]

上式就是边缘概率密度\(p_t(\mathbf{x}_t)\)的表达式。由于分布卷积运算等价于两个分布"相叠加”,所以实际上分布卷积运算后的分布就等于SMLD方法的加噪公式,只是随机变量 \(\mathbf{x}_t\) 除以一个系数 \(s(t)\),概率密度整体乘以了一个\(s^{-d}(t)\)。令

\[ p( \mathbf{x}_t; \sigma ( t) ) = \left [ p_\text{data}* \mathcal{N} \left ( 0, \sigma ^2( t) \mathbf{I} \right ) \right ] ( \mathbf{x} _t) = \mathcal{N} \left ( \mathbf{x} _t; \mathbf{x} , \sigma ^2( t) \mathbf{I} \right ) = p( \mathbf{x} _t| \mathbf{x} ) \]

注意这里为了方便说明问题假设 \(p_\text{data}= \mathcal{N} ( \mathbf{x} , 0)\), 也即数据集只有一个数据的时候,最后两个等号才成立。若数据集包含很多数据,参考EDM论文公式(45)可进行更严谨的推导, 可以看出这个式子就是score matching中的条件概率分布函数,也就是说边缘概率密度和加噪的条件概率密度等价

把公式(15)(16)(17)代入公式(13)中,可得:

\[ \begin{aligned}\mathrm{d}\mathbf{x}_{t}&=\left[f(t)\mathbf{x}_t-\frac12g^2(t)\nabla{\mathbf{x}_t}\log p_t(\mathbf{x}_t)\right]\mathrm{d}t\\&=\left[\frac{\dot{s}(t)}{s(t)}\mathbf{x}_t-s^2(t)\sigma(t)\dot{\sigma}(t)\nabla{\mathbf{x}_t}\log\left(s^{-d}(t)\left[p_\text{data}*\mathcal{N}\left(0,\sigma^2(t)\mathbf{I}\right)\right]\left(\frac{\mathbf{x}_t}{s(t)}\right)\right)\right]\mathrm{d}t\\&=\left[\frac{\dot{s}(t)}{s(t)}\mathbf{x}_t-s^2(t)\sigma(t)\dot{\sigma}(t)\left(\nabla{\mathbf{x}_t}\log s^{-d}(t)+\nabla{\mathbf{x}_t}\log p\left(\frac{\mathbf{x}_t}{s(t)};\sigma(t)\right)\right)\right]\mathrm{d}t\end{aligned} \]

再进一步, 我们就得到了通用概率流常微分方程(PFODE):

\[ \begin{equation}dx_t=\left[\frac{\dot{s}(t)}{s(t)}\mathbf{x}_t-s^2(t)\sigma(t)\dot{\sigma}(t)\nabla{\mathbf{x}_t}\log p\left(\frac{\mathbf{x}_t}{s(t)};\sigma(t)\right)\right]\mathrm{d}t\end{equation} \]

小结

💡 至此,通过将SDE转换为PFODE, 并结合对边缘概率分布 \(p_t(x)\)的推导,我们获得了仅仅依赖通用加噪公式中的\(s(t)\)\(\sigma(t)\), 而不显式依赖\(f(t)\)\(g(t)\)的通用概率流常微分方程PFODE