Reading

Consistency Models

Diffusion Models from SDE

连续扩散模型 (Continuous Diffusion Models) 将传统的离散时间扩散过程扩展到连续时间域,可以被视为一个随机过程,使用随机微分方程(SDE)来描述。其前向过程可以写成如下形式:

\[\mathrm d\mathbf x=\mathbf f(\mathbf x,t)\mathrm dt+g(t)\mathrm d\mathbf w\tag{1}\]

其中,\(f(x,t)\) 可以看成偏移系数,\(g(t)\) 可以看成是扩散系数,\(dw\) 是标准布朗运动。这个SDE
描述了数据在连续时间域内如何被噪声逐渐破坏。
这个随机过程的逆向过程存在(更准确的描述:下面的逆向时间SDE具有与正向过程SDE相同的联合分布)为

\[d\mathbf{x}=[\mathbf{f}(\mathbf{x},t)-g^2(t)\nabla_{\mathbf{x}}\log p_t(\mathbf{x})]dt+g(t)d\bar{\mathbf{w}}\tag{2}\]

前面我们得到了扩散过程的逆向过程可以用一个SDE描述(逆向随机过程),事实上,存在一个确定性过程 (用ODE描述)也是它的逆向过程 (更准确的描述:这个ODE过程的在任意时刻\(t\in[0,T]\) 的状态的边缘分布 \(p_t( \mathbf{x} ( t) )\) 与 SDE过程的相同

\[\mathrm d\mathbf x=\begin{bmatrix}f(\mathbf x,t)-\frac{1}{2}g^2(t)\nabla_{\mathbf x}\log p_t(\mathbf x)\end{bmatrix}\mathrm dt\tag{3}\]

其中,\(\nabla_{\mathbf{x}}\log p_{t}(\mathbf{x})\) 叫做score function , 可以通过score match进行训练,也兼容普通的扩散模型噪声预测模型。

这个ODE称为概率流ODE(probability flow ODE,PF ODE,它沿着概率流的轨迹(也就是ODE的解函数) 建立了从噪声分布中的点\(\mathbf{x}(T)\sim p_T(\mathbf{x}(T))\) 与到数据分布中的点 \(\mathbf{x}(0)\sim p_T(\mathbf{x}(T))\)的映射,也就是说PF ODE建立了高斯分布样本与数据分布样本的映射。

因此,采样时,把\(s_{\theta}(\mathbf{x},t)\approx\nabla\log p_{t}(\mathbf{x})\) 代入PF-ODE, 仅需采样一次\(\mathbf{x}(T)\sim p_T(\mathbf{x}(T))\), 然后可以确定性的解出一条从\(\mathbf{x}(T)\)\(\mathbf{x}(0)\)的轨迹,这样的到的\(\mathbf{x}(0)\)的分布就是data分布。可以用ODE数值求解器来求解出\(\mathbf{x}(0)\) ,例如欧拉法,Runge- Kutta法之类的数值方法。

Consistency Models

有了PF ODE,我们已经能做到采样一个高斯噪声,然后通过求解ODE,映射到数据。但求解ODE 仍然是一个迭代过程,要算很多次\(\mathbf{s}_\theta(\mathbf{x},t)\) ,效率很低。我们能不能直接拟合出来这个ODE表示的映射 (也就是ODE的解函数),然后一步到位?Consistency models (一致性模型)就是相当于直接学了个ODE的解:任给某个轨迹上的点\((\mathbf{x}(t),t)\) ,一步inference得到\(\mathbf{x}(0)\) !
首先我们回顾扩散过程SDE。在Consistency Models这篇论文里,为了最后的PF ODE形式的简单,他们采用了扩散项系数\(g(t)=\sqrt{2t}\) ,漂流项 \(\mathbf f(\mathbf x,t)\) 为0,

这样 \(p_{t}(x)=p_{data}(x)\otimes\mathcal{N}(0,t^{2}I)\),其中\(\otimes\)是卷积操作,\(\pi(x)=p_T(x)=\mathcal{N}(0,T^{2}I)\)。用 score model \(s_{\theta}(x,t)\approx\nabla\log p_{t}(x)\) 进行估计。在上述设置下, 带入到(1)(3), 可以写作:

\[\text{SDE:}\quad dx_t=\sqrt{2t}d\boldsymbol{w}t\tag{4}\]
\[\text{PF-ODE:}\quad\frac{dx_t}{dt}=-ts_\theta(x_t,t)\tag{5}\]

(5)为 empirical PF ODE。

image

对于(4)的 SDE, 可以对两边进行\([0,T]\)的积分:

\[\begin{aligned}&\int_{0}^{T}dx_{t}=\int_{0}^{T}\sqrt{2t}dw_{t}\\&x_{T}-x_{0}=\mathcal{N}(0,\int_{0}^{T}2tdt)\\&x_{T}-x_{0}=\mathcal{N}(0,T^{2}I)\\&x_{T}=x_0+\mathcal{N}(0,T^2I)\end{aligned}\]

其中第二行利用的是随机微分方程SDE和基于 Ito 积分的性质,也可以理解为这个等式右侧是对正态分布进行积分,相当于无穷个带有系数 \(\sqrt{2t}\) 的正态分布的相加,那么根据正态分布的线性性质,可以知道是无穷个\(\mathcal{N}(0,2tI)\)个正态分布相加,均值还是 0,方差为 \(\int_0^T2tdt=T^2\)最后一行相当于一个加噪的过程,由于方差 \(T^2\) 一般比较大,所以可以近似\(x_T\sim\mathcal{N}(0,T^2I)\) ,同样的有任意时间步 \(t\)的加噪过程 \(x_t=x_0+\mathcal{N}(0,t^2I)\)

\(\hat{x}_T\sim\pi(x)=\mathcal{N}(0,T^2I)\) 中采样初始噪声,然后使用 Euler/Heun 求解器求解(5), 可以得到一条采样轨迹\(\{\hat{x}_t\}_{t\in[0,T]}\) 。为了数值稳定性,通常在 \(t=\epsilon\) 处停止采样得到\(\hat{x}_\epsilon\) , 可以近似认为满足分布\(p_{data}(x)\)。当\(t\) 不是从 0 开始而是从 \(\epsilon\) 开始,对应的加噪过程也有:

\[\begin{aligned}&\int_{\epsilon}^{T}dx_{t}=\int_{\epsilon}^{T}\sqrt{2t}dw_{t}\\&x_{T}-x_{\epsilon}=\int_0^T\sqrt{2t}dw_t-\int_0^\epsilon\sqrt{2t}dw_t\\&x_{T}-x_{\epsilon}=\mathcal{N}(0,T^2I)-\mathcal{N}(0,\epsilon^2I)\\&x_{T}-x_{\epsilon}=\mathcal{N}(0,(T^2-\epsilon^2)I)\end{aligned}\]

由于\(T>>\epsilon\), 所以先验分布仍然可以近似为 \(x_T\sim\mathcal{N}(0,T^2I)\) 。并且,类似的有:

\[x_t=x_\epsilon+\mathcal{N}(0,(t^2-\epsilon^2)I)\tag{6}\]

定义

一致性函数

在给定概率流常微分方程(Probability Flow Ordinary Differential Equation, PFODE)的情况下,一致性函数可表示为下面的形
式:

\[f(\mathbf{x}_t,t)=\begin{cases}\mathbf{x}_\epsilon,&t=\varepsilon\\f(\mathbf{x}_{t'},t'),&t\in(\varepsilon,T],\forall t'\in[\varepsilon,T]\end{cases}\tag{7}\]

一致性模型的学习目标就是在给定PFODE上,得到一个神经网络使得任意相邻两个时刻的输出尽可能相等,进而逼近一致性函数。

如下图所示:

image

神经网络\(f_{\theta}:\mathbb{R}^{N+1}\rightarrow\mathbb{R}^{N}\) 的形式为:

\[f_\theta(\mathbf{x}_t,t)\:=\:C_{\mathrm{skip}}(t)\mathbf{x}_t\:+\:C_{\mathrm{out}}(t)F_\theta(\mathbf{x}_t,t)\tag{8}\]

其中 \(C_{\mathrm{skip}}\)\(C_{\mathrm{out}}\)是两个关于时间\(t\)的可微函数,保证\(C_{\mathrm{skip}}(\epsilon)=1\)\(C_{\mathrm{out}}(\epsilon)=0\),满足一致性函数的要求。

def denoise(self, model, x_t, sigmas, **model_kwargs):
    import torch.distributed as dist

    if not self.distillation:
        c_skip, c_out, c_in = [
            append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)
        ]
    else:
        c_skip, c_out, c_in = [
            append_dims(x, x_t.ndim)
            for x in self.get_scalings_for_boundary_condition(sigmas)
        ]

    rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
    model_output = model(c_in * x_t, rescaled_t, **model_kwargs)
    denoised = c_out * model_output + c_skip * x_t
    return model_output, denoised

可以发现,在实现时引入多了一个 \(C_{in}\),来限制输入图像,有点像LSTM的输入门。\(C_{skip}\)\(C_{out}\) 为人为设定,和时间相关,当\(t\)增加时,图像噪声水平增加,离采样终点越远远,\(C_{skip}\) 的值会下降,\(C_{out}\) 的值会上升,更少的输入信号 \(x_t\) 被保留,更多依靠模型去进行预测。相反,当 \(t\) 减少时,图像噪声水平下降,离采样终点越趋近,\(C_{skip}\) 的值会上升,\(C_{out}\) 的值会下降,更多的输入信号 \(x_t\) 被保留,更少依靠模型去进行预测。至于 \(C_{in}\),可以参考EDM

采样

一致性模型 \(f_\theta\) 训练好后,可以先从先验分布中采样\(\hat{x}_T\sim\pi(x)=\mathcal{N}(0,T^2I)\) ,然后通过\(\hat{x}_\epsilon=\boldsymbol{f}_\theta(\hat{x}_T,T)\)进行一步采样。
除此之外,还可以通过交替去噪和加噪 (在不同的时间点上进行一步加噪和一步去噪),提高样本
质量,即:

image

其中一步加噪使用的就是(6)

训练Loss

一致性模型采用了让PF-ODE相邻两个时间点模型输出值差距最小化的方式实现利用神经网络逼近一致性函数的目标,损失函数可以写为:

\[\mathcal{L}^N(\theta)=\mathbb{E}[\|f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1})-f_\theta(\hat{\mathbf{x}}_{t_n},t_n)\|_2^2]\tag{9}\]

其中,\(N\) 表示时间点设置的数目,\(\hat{\mathbf{x}}_{t_n}\)是通过一种ODE求解器获得的上一个时刻 \(t_n\) 的图像,这里用\(t_n\) 表示第 \(n\) 个时间点对应的时间,\(\mathbf{x} \sim p_{\mathrm{data}}, \mathbf{x} _{t_{n+ 1}}\sim \mathcal{N} ( \mathbf{x} , t_{n+ 1}^2\mathbf{I} ) , n\sim \mathcal{U} [ 1, N- 1]\) 且为整数。\(N\) 是一个超参数,理论这个值越大,在PF-ODE上的路径点数目就会越多,两点之间越靠近,\(\hat{\mathbf{x}}_{t_n}\) 的求解会更准,模型精度越好。实际上当\(N\)足够大以后,这个数值对模型性能影响已经不敏感。作者在这里并没有直接使用上面的损失函数形式,而是将上式后一项的 \(f\) 的权重 \(\theta\) 换成了模型的指数滑动平均值 (Exponential Moving Average, EMA) \(\theta^-\)。根据EMA的性质,给定衰减系数\(0\leq\mu<1\),可得:

\[\theta^-\leftarrow\text{stopgrad}(\mu\theta^-+(1-\mu)\theta)\]

此时,损失函数可写为:

\[\mathcal{L}^N(\theta,\theta^-)=\mathbb{E}[\|f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1})-f_{\theta^-}(\hat{\mathbf{x}}_{t_n},t_n)\|_2^2]\tag{10}\]

该式便是论文中损失函数的一种简化形式。采用EMA相当于用一个多轮加权平均的结果充当最终的模型,每一步不同的模型权重 \(\theta\) 对其影响较小,因此\(f_{\theta^-}\)又被称作"目标模型"。采用EMA的原因是可以提升训练过程的稳定性,提升一致性模型的效果。有了损失函数的形式了,就可以考虑如何训练模型了。论文中给定了两种方法, 一种是从已有模型切入,也即一致性蒸馏 (Consistency Distillation, CD),一种是从零开始训练一个新的一致性模型,也即一致性训练(Consistency Training, CT).

一致性蒸馏 (Consistency Distillation, CD)

一致性模型的训练可以从已有模型切入,比如采用score-based模型进行蒸馏。假设现在已经有一个模型\(s_\theta(\mathbf{x}_t,t)\), 很自然地,通过模型预测出一个score的值,再通过欧拉法采样一步,就可以获得在PF-ODE路径上相邻点的位置。根据PF-ODE形式,很显然有:

\[\hat{\mathbf{x}}_{t_n}=\mathbf{x}_{t_{n+1}}-(t_n-t_{n+1})t_{n+1}s_\theta(\mathbf{x}_{t_{n+1}},t_{n+1})\tag{11}\]

代入(10)损失函数中可得CD的损失函数形式为:

\[\mathcal{L}_{CD}^N(\theta,\theta^-) = \\ \mathbb{E}[\|f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1})-f_{\theta^-}(\mathbf{x}_{t_{n+1}}-(t_n-t_{n+1})t_{n+1}s_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}),t_n)\|_2^2]\tag{12}\]

论文中给出了CD过程的算法步骤如下,结合损失函数来看不难理解。首先从数据中采样样本,均匀采样时间点\(n\),根据时间加噪声。加完噪声后,采用ODE数值求解器获得上一个时间点的图像,紧接着计算损失函数并进行梯度反传。

image

其中,\(\Phi\) 表示模型学习的"司机”, 可以是score,也可以是速度场等ODE迭代必须元素。\(\lambda(\cdot)\in\mathbb{R}^+\)是权重函数,作者设置恒等于 1, \(d(\cdot,\cdot)\) 是衡量两个输入距离的指标,作者考虑了 L1、L2 和 LPIPS。需要注意的是,每一步需要同步更新EMA权重\(\theta^-\)的值。

一致性训练(Consistency Training, CT)

一致性模型难道一定要背靠一个现成的扩散模型吗?也可以从零开始训练一个一致性模型。可以发现,没有现成扩散模型的最大问题是 \(\Phi\) 没了,没办法进行ODE的数值求解了。作者就给定了一个score-based模型\(s_θ(x_t,t)\)预测结果\(\nabla \log p_t(x_t)\)的平替,用这个平替让模型进行学习,进而替代\(s_\theta(x_t,t)\)。这个平替需要至少满足以下两个条件:

  1. \(\nabla \log p_t(x_t)\)的无偏估计
  2. 需要在训练过程中作为金标准,因此需要可以计算获得

基于上述两个条件,作者给出了平替的形式为:

\[ \nabla_{x_t} \log p_t(x_t) = -\mathbb{E}\left[\frac{x_t - x}{t^2}|x_t\right] \tag{13}\]

其中,\(x_t \sim \mathcal{N}(x,t^2\mathbb{I})\)\(x \sim p_{data}\),可以发现。

论文中推导的前几步可能稍微难理解一点,在这里给大家全部补上。回忆一下概率论的基本概念,根据边缘概率密度与联合概率密度的关系,可得如下的等式:

\[p_t(x_t) = \int p(x_t,x)dx = \int p(x_t|x)p_{data}(x)dx\]

接下来就是公式推导,公式推导的目的就是从score的定义出发,探索到底有没有一个关于score的无偏估计量。总体而言,宋博士在这里的推导过程相当清晰,几乎没有跳步,仅在开始两步稍微有点难度,在这里给大家一一分解。根据score的定义,\(\log p_t(x_t)\)关于\(x_t\)求导,实际上是一个复合函数求导!也即先对log求导,再对\(p_t(x_t)\)求导!心中牢记复合函数求导,则推导过程并不难理解,如下所示:

\[\begin{aligned} \nabla_{x_t} \log p_t(x_t) &= \frac{1}{p_t(x_t)} \cdot \nabla_{x_t}p_t(x_t) \quad \text{复合函数求导} \ \\&= \frac{\nabla_{x_t} \int p(x_t|x)p_{data}(x)dx}{p_t(x_t)} \ \\&= \frac{\int \nabla_{x_t}p(x_t|x)p_{data}(x)dx}{p_t(x_t)} \quad \text{莱布尼兹法则} \ \\&= \frac{\int \nabla_{x_t} \log p(x_t|x)p(x_t|x)p_{data}(x)dx}{p_t(x_t)} \quad \text{反用复合函数求导} \ \\&= \int \nabla_{x_t} \log p(x_t|x)\frac{p(x_t|x)p_{data}(x)}{p_t(x_t)}dx \ \\&= \int \nabla_{x_t} \log p(x_t|x)p(x|x_t)dx \quad \text{贝叶斯公式} \ \\&= \mathbb{E}{x_t\sim\mathcal{N}(x,t^2\mathbb{I}),x\sim p{data}}[\nabla_{x_t} \log p(x_t|x)|x_t] \quad \text{条件期望定义} \ \\&= -\mathbb{E}{x_t\sim\mathcal{N}(x,t^2\mathbb{I}),x\sim p{data}}\left[\frac{x_t-x}{t^2}|x_t\right] \quad \text{代入正态分布公式} \end{aligned}\]

通过推导可以发现,这个平替确实是score的无偏估计。

既然无偏估计出现了,对于原始损失函数,只需要把原来依靠模型预测值\(s_\theta(\mathbf{x}_{t_{n+1}},t_{n+1})\)预测的\(\hat{\mathbf{x}}_{t_n}\)直接写为平替形式
也即

\[\begin{aligned}&\hat{\mathbf{x}}_{t_n}=\mathbf{x}_{t_{n+1}}-(t_n-t_{n+1})t_{n+1}s_\theta(\mathbf{x}_{t_{n+1}},t_{n+1})\\&\hat{\mathbf{x}}_{t_n}=\mathbf{x}_{t_{n+1}}+(t_n-t_{n+1})t_{n+1}\frac{\mathbf{x}_{t_{n+1}}-\mathbf{x}}{t_{n+1}^2}=\mathbf{x}_{t_{n+1}}+(t_n-t_{n+1})\frac{\mathbf{x}_{t_{n+1}}-\mathbf{x}}{t_{n+1}}\end{aligned}\tag{14}\]

所以,CT的损失函数形式为:

\[\mathcal{L}_{CT}^N(\theta,\theta^-)=\mathbb{E}\left[\left\|f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1})-f_{\theta^-}\left(\mathbf{x}_{t_{n+1}}+(t_n-t_{n+1})\frac{\mathbf{x}_{t_{n+1}}-\mathbf{x}}{t_{n+1}},t_n\right)\right\|_2^2\right]\]

论文给出了基于CT损失函数的模型训练算法步骤,可以发现相比CD来说,主要不同点为相邻时刻(也即上一时刻)的图像直接通过原始图像加噪而非ODE迭代获得 (与(14)等价)。

image

如何证明这个估计是正确有效的?假如我们已经有了一个ground truth score function
\(\bold s_{\phi^*}(\bold x,t)=\nabla\log p_t(\bold x_t)\),可以证明这个估计的loss与给定ground truth score function的loss之间的误差趋于0,即\(\mathcal L_{CD}^N(\theta,\theta^-;\phi^*) - \mathcal L_{CT}^N(\theta,\theta^-)=o(\Delta t)\)。作者通过泰勒展开证明了这个结论。

在实验中,步数\(N\)随训练逐渐增加,可以加快收敛。

image

效果上,目前diffusion > CD > diffusion distillation > CT,(GAN的效果不太好比较,在有些数据集上,经过这么多年充分调教的GAN性能甚至再次反超diffusion,有的数据集上则不如CT)。作为单步方法的CD和CT打不过diffusion,其实可以理解,一方面这one-one mapping本身非常复杂,还是比较难拟合的;另一方面,consistency models训练方法是约束一条ODE轨迹上的相邻点输出相同,模型拟合不好的话,每一步的误差都会累积下来。

不过consistency model才刚刚提出不久,在训练方法上还有很多待研究的技术。而作为一种单步生成方法,相对于diffusion在性能上有着显著优势。

连续一致性模型

损失函数推导:以一致性蒸馏中 \(\theta^- = \theta\)为例

上述的一致性模型是基于离散时间序列 \(\{t_0, \dots, t_N\}\),最典型的体现就是得用一个 ODE 求解器从 \(\mathbf{x}_{t_{n+1}}\)获得\(\mathbf{\hat{x}}_{t_{n}}\)。ODE 求解器一定存在由于离散时间导致的截断误差,也就从理论上无法达到最优。离散到连续的转化主要体现在 ODE 求解器的迭代公式,在离散情况下, \(d t\)可以近似为 \(\Delta t = t_{n+1} - t_n\), 在连续情况下我们有一个关于 \(u\) 的连续时间函数 \(t = \tau(u)\),所以 \(\Delta t \approx \frac{d\tau(u)}{du}\Delta u = \tau'(u)\Delta u\)。 其中,\(\Delta u \to 0\),也即

  • 离散情形:
\[\mathbf{\hat{x}}_{t_{n}} = \mathbf{x}_{t_{n+1}} + t_{n+1}s_\phi(\mathbf{x}_{t_{n+1}}, t_{n+1})(t_{n+1} - t_n)\tag{15}\]
  • 连续情形:
\[ \mathbf{\hat{x}}_{t_{n}} = \mathbf{x}_{t_{n+1}} + t_{n+1}s_\phi(\mathbf{x}_{t_{n+1}}, t_{n+1})\tau'(u)\Delta u\tag{16}\]

上述看似简单的改变会导致一致性模型学习目标(损失函数)的改变,一致性模型论文分了很多种情况进行推导讨论。让我们先考虑最简单的情况,先认为 \(\theta^- = \theta\),也即不使用 EMA 权重,考虑一致性蒸馏场景。\(\tau(u_n) = \tau\left(\frac{n-1}{N-1}\right), \quad n \in [1, N]\)\(n \in \mathbb{Z}, \quad u_n = \frac{n-1}{N-1}, \Delta u=\frac{1}{N-1}\),连续形式的损失函数可以改写为:

\[ \mathbb{E}\left[\lambda(t_n)d\left(f_\theta(\mathbf{x}_{t_{n+1}}, t_{n+1}), f_\theta(\mathbf{x}_{t_{n+1}}, t_{n+1}) + t_{n+1}s_\phi f_\theta(\mathbf{x}_{t_{n+1}}, t_{n+1})\tau'(u_n)\Delta u, t_n\right)\right]\tag{17}\]

(17)是以一种“如连续”的状态,你说它是连续的,它毕竟没有显式的离散两个时间点的差了;你说它不是连续的,它毕竟还存在一个好像不那么和谐的 \(\Delta u\),而且 \(\tau'(u)\) 还是存在的。(17)虽然形式有了,但很难用代码实现。现在看来似乎没有什么办法了,毫无思路。毫无思路的时候,记住一个口诀:泰勒救我! 光有泰勒显然还不够,我们回想一致性模型的学习目标是希望相邻两个时间点的输出尽可能的接近,最简单的想法就是,先无脑求一下下面这个式子:

\[f_\theta(\mathbf{\hat{x}}_{t_n}, t_n) - f_\theta(\mathbf{x}_{t_{n+1}}, t_{n+1})\tag{18}\]

很显然,直接求没法求,这个时候泰勒展开就能出场了:

\[\begin{aligned}&f_{\theta}(\hat{\mathbf{x}}_{t_{n}},t_{n})-f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})\\&=f_{\theta}(\mathbf{x}_{t_{n+1}}+t_{n+1}s_{\phi}(\mathbf{x}_{t{n+1}},t_{n+1})\tau^{\prime}(u_{n})\Delta u,t_{n}))-f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})\\&=f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})+t_{n+1}\frac{\partial f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})}{\partial\mathbf{x}_{t_{n+1}}}s_{\phi}(\mathbf{x}_{t_{n+1}},t_{n+1})\tau^{\prime}(u_{n})\Delta u\\&+\frac{\partial f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})}{\partial t_{n+1}}\underbrace{(t_{n}-t_{n+1})}_{-\tau^{\prime}(u_{n})\Delta u}-f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})+\underbrace{O((\Delta u)^{2})}_{\text{Lagrauge remainder}}\\&=t_{n+1}\frac{\partial f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})}{\partial\mathbf{x}_{t_{n+1}}}s_{\phi}(\mathbf{x}_{t_{n+1}},t_{n+1})\tau^{\prime}(u_{n})\Delta u \\&-\frac{\partial f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})}{\partial t_{n+1}}\tau^{\prime}(u_{n})\Delta u+O((\Delta u)^{2}) \end{aligned}\tag{19}\]

(19)暂时现放着,后面有用。还有一个式子似乎可以用泰勒展开,那就是距离度量函数\(d\),也即:

\[d(f_\theta(\mathbf{x}_{t_{n+1}}),t_{n+1}),f_\theta(\hat{\mathbf{x}}_{t_n},t_n))\tag{20}\]

从始至终,影响我们无法进行实现连续一致性模型的"罪魁祸首"就是这个\(\hat{\mathbf{x}}_{t_n}\), 既然看它不爽,那我
们对(20)的第二个输入\(f_\theta(\hat{\mathbf{x}}_n,t_n)\)\(f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1})\) 处进行泰勒展开。需要注意,这里实际上是一个"一元”函数的泰勒展开,有:

\[\begin{aligned}&d(f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}),f_\theta(\hat{\mathbf{x}}_{t_n},t_n)) \\&=d(f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}),f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1})) \\&+\frac{\partial d(f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}),f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}))}{\partial f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1})}(f_\theta(\hat{\mathbf{x}}_{t_n},t_n)-f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}))\\&+\frac12(f_\theta(\hat{\mathbf{x}}_{t_n},t_n)-f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}))^T\mathbf{G}(f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}))(f_\theta(\hat{\mathbf{x}}_{t_n},t_n)-f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}))+O(|\Delta u|^3)\end{aligned} \tag{21}\]

这里引入了一个新的矩阵\(\mathbf{G}(\mathbf{x})\),它的具体形式如下:

\[[\mathbf{G}(\mathbf{x})]_{ij}:=\frac{\partial^2d(\mathbf{x},\mathbf{y})}{\partial y_i\partial y_j}\Big|_{\mathbf{y}=\mathbf{x}}\]

这个矩阵实际上就是黑塞 (Hessian) 矩阵,记录了一个标量输出的函数与输入向量所有元素可能的二阶偏导数。注意(21),有一些项是可以化简的:

  1. 对于距离度量\(d(\mathbf{x},\mathbf{y})\),在同一位置的两点距离为0,也即\(d(\mathbf{x},\mathbf{y})|_{\mathbf{y}=\mathbf{x}}=0\),所以有:
\[d(f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}),f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}))=0\tag{22}\]
  1. 对于距离度量\(d(\mathbf{x},\mathbf{y})\), 当两点距离位置相同时的梯度为0,也即\(\nabla_\mathbf{y}d(\mathbf{x},\mathbf{y})|_\mathbf{y=\mathbf{x}}\equiv0\)。所以有:
\[\frac{\partial d(f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}),f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}))}{\partial f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1})}=0\tag{23}\]

有了(22)(23),实际上(21)就只剩下二阶导数和余项了,也即:

\[\begin{aligned}&d(f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}),f_\theta(\hat{\mathbf{x}}_{t_n},t_n))\\&=\frac12(f_\theta(\hat{\mathbf{x}}_{t_n},t_n)-f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}))^T\mathbf{G}(f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}))(f_\theta(\hat{\mathbf{x}}_{t_n},t_n)-f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}))+O(|\Delta u|^3)\end{aligned} \tag{24}\]

回想一致性模型的损失函数,代入(24),可得:

\[\begin{aligned}&\mathbb{E}_{\mathbf{x}\sim p_{\mathrm{data}},u\sim\mathcal{U}[0,1],\mathbf{X}_{t}\sim\mathcal{N}(\mathbf{x},t^{2}\mathbf{I})}[\lambda(t_{n})d(f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1}),f_{\theta}(\hat{\mathbf{x}}_{t_{n}},t_{n}))]\\&=\mathbb{E}_{\mathbf{x},\mathbf{u},\mathbf{x}_{t}}\left[\lambda(t_{n})\frac{1}{2}(f_{\theta}(\hat{\mathbf{x}}_{t_{n}},t_{n})-f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1}))^{T} \mathbf{G}(f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1}))(f_{\theta}(\hat{\mathbf{x}}_{t_{n}},t_{n})-f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1}))\right]\\&+\mathbb{E}[O(|\Delta u|^3)]\end{aligned}\tag{25}\]

到此,(19)就派上用场了,将它代入(25)中,有:

\[\begin{aligned}&\mathbb{E}_{\mathbf{x}\sim p_{\mathrm{data}},u\sim\mathcal{U}[0,1],\mathbf{X}_{t}\sim\mathcal{N}(\mathbf{x},t^{2}\mathbf{I})}[\lambda(t_{n})d(f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1}),f_{\theta}(\hat{\mathbf{x}}_{t_{n}},t_{n}))] \\ &=\mathbb{E}_{\mathbf{x},\mathbf{u},\mathbf{x}_{t}}\big[\lambda(t_{n})\frac{1}{2}\left( t_{n+1}\frac{\partial f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})}{\partial\mathbf{x}_{t_{n+1}}}s_{\phi}(\mathbf{x}_{t_{n+1}},t_{n+1})\tau^{\prime}(u_{n})\Delta u -\frac{\partial f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})}{\partial t_{n+1}}\tau^{\prime}(u_{n})\Delta u \right)^{T} \\ &\mathbf{G}(f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1}))\left( t_{n+1}\frac{\partial f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})}{\partial\mathbf{x}_{t_{n+1}}}s_{\phi}(\mathbf{x}_{t_{n+1}},t_{n+1})\tau^{\prime}(u_{n})\Delta u -\frac{\partial f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})}{\partial t_{n+1}}\tau^{\prime}(u_{n})\Delta u \right)\big] \\&+\mathbb{E}[O(|\Delta u|^3)] \\&=\frac{(\Delta u)^2}{2}\mathbb{E}_{\mathbf{x},\mathbf{u},\mathbf{x}_{t}}\big[\lambda(t_{n})\tau^{\prime}(u_{n})^2\left( t_{n+1}\frac{\partial f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})}{\partial\mathbf{x}_{t_{n+1}}}s_{\phi}(\mathbf{x}_{t_{n+1}},t_{n+1}) -\frac{\partial f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})}{\partial t_{n+1}} \right)^{T} \\ &\mathbf{G}(f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1}))\left( t_{n+1}\frac{\partial f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})}{\partial\mathbf{x}_{t_{n+1}}}s_{\phi}(\mathbf{x}_{t_{n+1}},t_{n+1})-\frac{\partial f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})}{\partial t_{n+1}}\right)\big]+\mathbb{E}[O(|\Delta u|^3)] \end{aligned}\tag{26}\]

因为\(\Delta u=\frac1{N-1}\), 观察(26),发现这个\((\Delta u)^2\)是一个微小量,很可能就是未来两边同时取极限能用到,不妨先移动到公式左边,也即等式两边同时乘以\(\frac1{(\Delta u)^2}\),可得:

\[\begin{aligned}&\frac1{(\Delta u)^2}\mathbb{E}_{\mathbf{x}\sim p_{\mathrm{data}},u\sim\mathcal{U}[0,1],\mathbf{X}_{t}\sim\mathcal{N}(\mathbf{x},t^{2}\mathbf{I})}[\lambda(t_{n})d(f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1}),f_{\theta}(\hat{\mathbf{x}}_{t_{n}},t_{n}))] \\&=\frac{(\Delta u)^2}{2}\mathbb{E}_{\mathbf{x},\mathbf{u},\mathbf{x}_{t}}\big[\frac{\lambda(t_{n})}{[(\tau^{-1})^{\prime}(t_{n})]^2}\left( t_{n+1}\frac{\partial f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})}{\partial\mathbf{x}_{t_{n+1}}}s_{\phi}(\mathbf{x}_{t_{n+1}},t_{n+1}) -\frac{\partial f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})}{\partial t_{n+1}} \right)^{T} \\ &\mathbf{G}(f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1}))\left( t_{n+1}\frac{\partial f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})}{\partial\mathbf{x}_{t_{n+1}}}s_{\phi}(\mathbf{x}_{t_{n+1}},t_{n+1})-\frac{\partial f_{\theta}(\mathbf{x}_{t_{n+1}},t_{n+1})}{\partial t_{n+1}}\right)\big]+\mathbb{E}[O(|\Delta u|^3)] \end{aligned}\tag{27}\]

其中\(\tau^{-1}(t_n)=u_n\)\(\tau(u_n)=t_n\)的反函数,因此具备反函数导数性质,也即\(\tau^{\prime}(u_n)=\frac1{(\tau^{-1})^{\prime}(t_n)}\)。根据(27),结合一致性模型损失函数定义,有:

\[\begin{aligned}(N-1)^2\mathcal{L}_{CD}^N(\theta,\theta;\phi)&=\frac{1}{(\Delta u)^2}\mathcal{L}_{CD}^N(\theta,\theta;\phi)\\&=\frac{1}{(\Delta u)^2}\mathbb{E}_{\mathbf{x},u,\mathbf{x}_t}[\lambda(t_n)d(f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}),f_\theta(\hat{\mathbf{x}}_{t_n},t_n))]\end{aligned}\tag{28}\]

连续一致性模型的连续来自于对(28)两边同时取极限,也即\(N\to\infty\)或者\(\Delta u\to0\)。当\(\Delta u\to0\)的时候,有 \(t_{n+1}=\tau(u_n+\Delta u)\to\tau(u_n)=t_n\), 也即\(t_{n+1}\)可以用\(t_n\)替代了。进一步的,连续时间下,每个时间点都可以取得到,也就不存在第\(n\)个时间,\(t_n\)可以直接去掉下标\(n\)变为 \(t\)。此外,(27)中的\(\mathbb{E}[O(|\Delta u|)]\to0\),这项可以忽略了。作者定义(27)\(N\to\infty\) (\(\Delta u\to 0\)) 条件下的形式即为\(\mathcal{L}_{CD}^\infty(\theta,\theta;\phi)\)

\[\begin{aligned}&{L}_{CD}^N(\theta,\theta;\phi) \\&=\frac{1}{2}\mathbb{E}_{\mathbf{x},\mathbf{u},\mathbf{x}_{t}}\big[\frac{\lambda(t)}{[(\tau^{-1})^{\prime}(t)]^2}\left( t_{n+1}\frac{\partial f_{\theta}(\mathbf{x}_{t},t)}{\partial\mathbf{x}_{t}}s_{\phi}(\mathbf{x}_{t},t) -\frac{\partial f_{\theta}(\mathbf{x}_{t},t)}{\partial t} \right)^{T} \\ &\mathbf{G}(f_{\theta}(\mathbf{x}_{t},t))\left( t_{n+1}\frac{\partial f_{\theta}(\mathbf{x}_{t},t)}{\partial\mathbf{x}_{t}}s_{\phi}(\mathbf{x}_{t},t)-\frac{\partial f_{\theta}(\mathbf{x}_{t},t)}{\partial t}\right)\big] \end{aligned}\tag{29}\]

连续一致性模型定理

一致性蒸馏

(29)就是连续一致性模型的一种损失函数形式,根据(29)宋博士提出了如下定理

定理1.\(t_n=\tau\left(\frac{n-1}{N-1}\right)\),其中\(n\in[1,N]\)且为整数,\(\tau(\cdot)\)是一个严格单调函数,满足\(\tau(0)=\varepsilon\),\(\tau(1)=T\)。假设 \(\tau\) 在[0,1]区间连续可微,\(d\) 是一个三阶连续可导函数且三阶导数有界,\(f_\theta\) 二阶连续可导且一阶和二阶导数均有界。又假设权重函数\(\lambda(\cdot)\) 是有界的,满足\(\sup_{\mathbf{x},t\in[\varepsilon,T]}\|s_\phi(\mathbf{x},t)\|_2<\infty\)。如果在一致性蒸馏中使用Euler ODE求解器,有:
\[\lim_{N\to\infty}(N-1)^2\mathcal{L}_{CD}^N(\theta,\theta;\phi)=\mathcal{L}_{CD}^\infty(\theta,\theta;\phi)\tag{30}\]

\(\mathcal{L}_{CD}^\infty(\theta,\theta;\phi)\)定义为(29)的形式,\(\mathbf{x} \sim p_\text{data}, u\sim \mathcal{U} [ 0, 1] , t= \tau ( u) , \mathbf{x} _t\sim \mathcal{N} ( \mathbf{x} , t^2\mathbf{I} )\)

作者在定理1 的基础上给出了三点"评论”,咱们一个一个来看。首先评论1,解释了采用其他求解器可能获得类似结论。

评论1.1. 尽管定理1假设了使用Euler求解器简化计算,但是我们相信更通用的求解器也能推导出类似的结果,因为所有的ODE求解器在$N\to\infty$的条件下性能类似。我们将定理1的更通用证明留作未来的工作。

评论2给出了一种特殊形式,当距离度量\(d(\mathbf{x},\mathbf{y})=\|\mathbf{x}-\mathbf{y}\|_{2}^{2}\) 时,损失函数(29)的形式可以得到简化,简化后的形式可以用雅可比向量积求出。

评论1.2. 定理1 意味着可以通过最小化\(\mathcal{L}_{CD}^\infty(\theta,\theta;\phi)\) 来训练一致性模型。特别的,当\(d(\mathbf{x},\mathbf{y})=\|\mathbf{x}-\mathbf{y}\|_2^2\)时,有:
\[\mathcal{L}_{CD}^\infty(\theta,\theta;\phi)=\mathbb{E}\left[\frac{\lambda(t)}{[(\tau^{-1})'(t)]^2}\left\|t\frac{\partial f\theta(\mathbf{x}_t,t)}{\partial\mathbf{x}t}s_\phi(\mathbf{x}_t,t)-\frac{\partial f_\theta(\mathbf{x}_t,t)}{\partial t}\right\|_2^2\right]\tag{31}\]

然而,要求出这个连续时间损失函数需要计算雅可比向量积,如果深度学习框架不支持forward-mode (梯度前向传播)自动微分,这会是一个很慢而费力的过程。

评论3说明(31)为0的时候,确实达到了\(f_\theta(\mathbf{x},t)\)完美符合一致性函数的要求

评论1.3. 如果\(f_\theta(\mathbf{x}_t,t)\)在给定\(s_\phi(\mathbf{x}_t,t)\)的前提下达到了经验PFODE在给定\(s_\phi(\mathbf{x}_t,t)\)的一致性函数最优解,有
\[t\frac{\partial f_\theta(\mathbf{x}_t,t)}{\partial\mathbf{x}_t}s\phi(\mathbf{x}_t,t)-\frac{\partial f\theta(\mathbf{x}_t,t)}{\partial t}\equiv0\tag{32}\]

也即\(\mathcal{L}_{CD}^\infty(\theta,\theta;\phi)=0\)

评论1.3的内容是可以证明的,考虑到一致性函数的特点\(f_\theta(\mathbf{x}_t,t)\equiv\mathbf{x}_e\) 对于所有的\(t\in[\varepsilon,T]\), 从这个点来出发,结合对时间求导,即可获得 (32),有:

\[f_\theta(\mathbf{x}_t,t)\equiv\mathbf{x}_\varepsilon\]

两边对时间求导,\(\mathbf{x}_\varepsilon\)是原始图像与时间无关,可得:

\[\frac{\partial f_\theta(\mathbf{x}_t,t)}{\partial\mathbf{x}_t}\frac{\mathrm{d}\mathbf{x}_t}{\mathrm{d}t}+\frac{\partial f_\theta(\mathbf{x}_t,t)}{\partial t}\equiv0\]

不要忘记\(\frac{\mathrm{d}\mathbf{x}_t}{\mathrm{d}t}=-ts_\phi(\mathbf{x}_t,t)\), 即(5)

\[\frac{\partial f_\theta(\mathbf{x}_t,t)}{\partial\mathbf{x}_t}[-ts_\phi(\mathbf{x}_t,t)]+\frac{\partial f_\theta(\mathbf{x}_t,t)}{\partial t}\equiv0\]

两边乘以-1, 初等运算可得:

\[t\frac{\partial f_\theta(\mathbf{x}_t,t)}{\partial\mathbf{x}t}s_\phi(\mathbf{x}_t,t)-\frac{\partial f_\theta(\mathbf{x}_t,t)}{\partial t}\equiv0\]

上述证明过程说明了当连续一致性模型的损失函数达到0的时候,该模型完全符合一致性函数的特性。

然而,当距离度量是L1的时候,二阶偏导怎么求都是0,也就导致(21)对应的矩阵G为全0矩阵,显然没法计算了。因此,面对L1举例度量,作者给出了另外一个定理证明方法很简单,代入(19)整理下即可,在此不再赘述,直接上定理。

定理2.\(t_n=\tau\left(\frac{n-1}{N-1}\right)\),其中\(n\in[1,N]\)且为整数,\(\tau(\cdot)\)是一个严格单调函数,满足$\tau(0)=\varepsilon$,\(\tau(1)=T\)。假设 \(\tau\) 在[0,1]区间连续可微,\(d\) 是一个三阶连续可导函数且三阶导数有界,\(f_\theta\) 二阶连续可导且一阶和二阶导数均有界。又假设权重函数\(\lambda(\cdot)\) 是有界的,满足\(\sup_{\mathbf{x},t\in[\varepsilon,T]}\|s_\phi(\mathbf{x},t)\|_2<\infty\)。如果在一致性蒸馏中使用Euler ODE求解器,距离度量\(d(\mathbf{x},\mathbf{y})=\|\mathbf{x}-\mathbf{y}\|_1\), 有:
\[\lim_{N\to\infty}(N-1)\mathcal{L}_{CD}^N(\theta,\theta;\phi)=\mathcal{L}_{CD,\ell_1}^\infty(\theta,\theta;\phi)\tag{33}\]

\(\mathcal{L}_{CD,\ell_1}^\infty(\theta,\theta;\phi)\) 的形式为:
\[\mathcal{L}_{CD,\ell_1}^\infty(\theta,\theta;\phi):=\mathbb{E}\left[\frac{\lambda(t)}{(\tau^{-1})'(t)}\left\|t\frac{\partial f_\theta(\mathbf{x}_t,t)}{\partial\mathbf{x}_t}s_\phi(\mathbf{x}_t,t)-\frac{\partial f_\theta(\mathbf{x}_t,t)}{\partial t}\right\|_1\right]\tag{34}\]

其中,\(\mathbf{x} \sim p_{\mathrm{data}}\),\(u\sim\mathcal{U}[0,1],t=\tau(u),\mathbf{x}_t\sim\mathcal{N}(\mathbf{x},t^2\mathbf{I})\)

同样,定理2作者也给出了一些评论,咱们一个一个来看:

评论2.1. 按照(34)的损失函数训练模型确实能够达到一致性函数拟合的最优解,也即当且仅当 \(f_\theta(\mathbf{x}_t,t)=\mathbf{x}_c\) 对所有 \(\mathbf{x}_t\in\mathbb{R}^d\)\(t\in[\varepsilon,T]\) 都成立时,\(\mathcal{L}_{CD},\ell_1^{\infty}(\theta,\theta;\phi)=0\)

这个证明过程和评论1.3完全一致,不再赘述。

第二种情况,类似于离散一致性模型,令\(\theta ^- =stopgrad( \theta )\),也即复制一个原始模型但不计算梯度,注意这里并没有采用指数滑动平均值模型,也可以理解为在指数滑动平均值的公式中取\(\mu=0\)。经过类似的推导过程,有如下结论:

定理3.令 \(t_n=\tau\left(\frac{n-1}{N-1}\right)\) ,其中 \(n\in[1,N]\) 且为整数,\(\tau(\cdot)\) 是一个严格单调函数,满足 $\tau(0)=\varepsilon$,\(\tau(1)=T\)。假设 \(\tau\) 在[0,1]区间连续可微,\(d\) 是一个三阶连续可导函数且三阶导数有界,\(f_\theta\) 二阶连续可导且一阶和二阶导数均有界。又假设权重函数\(\lambda(\cdot)\) 是有界的,满足\(\sup_{\mathbf{x},t\in[\varepsilon,T]}\|s_\phi(\mathbf{x},t)\|_2<\infty\)。如果在一致性蒸馏中使用Euler ODE求解器,令\(\theta^-=stopgrad(\theta)\),则有:
\[\lim_{N\to\infty}(N-1)\nabla_\theta\mathcal{L}_{CD}^N(\theta,\theta^-;\phi)=\nabla_\theta\mathcal{L}_{CD}^\infty(\theta,\theta^-;\phi)\tag{35}\]

\(\mathcal{L}_{CD}^\infty(\theta,\theta^-;\phi)\) 的形式为:
\[\mathcal{L}_{CD}^\infty(\theta,\theta^-;\phi):= > \mathbb{E}\left[\frac{\lambda(t)}{(\tau^{-1})^{\prime}(t)}f_\theta(\mathbf{x}_t,t)^T > \mathbf{H}(f_{\theta^-}(\mathbf{x}_t,t))\left(\frac{\partial f_{\theta^-}(\mathbf{x}_t,t)}{\partial t}-t\frac{\partial f_{\theta^-}(\mathbf{x}_t,t)}{\partial\mathbf{x}_t}s_\phi(\mathbf{x}_t,t)\right)\right]\]

其中,\(\mathbf{x} \sim p_\text{data}\), \(u\sim\mathcal{U}[0,1],t=\tau(u),\mathbf{x}_t\sim\mathcal{N}(\mathbf{x},t^2\mathbf{I})\)

证明方法总结起来还是四个字“泰勒救我”,论文写的也十分详尽,再此不再赘述。可以发现这里多了一个矩阵\(\mathbf{H}(\mathbf{x})\),它和\(\mathbf{G(x)}\)类似,也是一个黑塞矩阵,形式如下:

\[[\mathbf{H}(\mathbf{x})]_{ij}:=\left.\frac{\partial^2d(\mathbf{y},\mathbf{x})}{\partial y_i\partial y_j}\right|{\mathbf{y}=\mathbf{x}}\]

它和\(\mathbf{G}(\mathbf{x})\) 矩阵的唯一区别就是偏导项中的\(d(\cdot,\cdot)\) 交换了两个入参的位置。同样,作者基于定理3同样给出了几个评论

评论3.1.当\(d(\mathbf{x},\mathbf{y})=\|\mathbf{x}-\mathbf{y}\|_2^2\) 时,"伪目标"\(\mathcal{L}_{CD}^\infty(\theta,\theta^-;\phi)\)可以化简为:
\[\mathcal{L}_{CD}^\infty(\theta,\theta^-;\phi)=2\mathbb{E}\left[\frac{\lambda(t)}{(\tau^{-1})'(t)}f_\theta(\mathbf{x}_t,t)^T\left(\frac{\partial f_{\theta^-}(\mathbf{x}_t,t)}{\partial t}-t\frac{\partial f_{\theta^-}(\mathbf{x}_t,t)}{\partial\mathbf{x}_t}s_\phi(\mathbf{x}_t,t)\right)\right]\]

评论3.1也是基于L2范数做的简化,推导过程也很简单,直接计算即可,在此省略。

评论3.2.定理3的中的损失函数 \(\mathcal{L}_{CD}^\infty(\theta,\theta^-;\phi)\) 并不是真正的训练目标,只有它的梯度具有意义,也即在训练过程中不能通过监控 \(\mathcal{L}_{CD}^\infty(\theta,\theta^-;\phi)\) 的值决定模型训练情况。然而,可以通过对\(\mathcal{L}_{CD}^\infty(\theta,\theta^-;\phi)\) 使用梯度下降法从预训练的扩散模型中蒸馏出一致性模型。显然,这个损失函数并不是一个典型的训练目标,所以在一致性蒸馏中成为"伪目标”。

评论3.2说明了当 \(\theta^-=\text{stopgrad}(\theta)\) 时,虽然仍然对 \(\mathcal{L}_{CD}^\infty(\theta,\theta^-;\phi)\) 使用梯度下降法,但目标是使得\(\nabla_\theta\mathcal{L}_{CD}^\infty(\theta,\theta^-;\phi)=\mathbf{0}\),也很显然当 \(\mathcal{L}_{CD}^\infty(\theta,\theta^-;\phi)\) 接近 0 的时候,它的梯度向量的每个分量也会很小,所以采用正常的梯度下降法是仍然可以训练一致性模型。需要注意的是,\(\mathcal{L}_{CD}^{\infty}(\theta,\theta^{-};\phi)\) 本身没有意义,不能通过查看\(\mathcal{L}_{CD}^{\infty}(\theta,\theta^{-};\phi)\) 的值来判断训练情况。

评论3.3.如果 \(f_\theta(\mathbf{x}_t,t)\) 在给定 \(s_\phi(\mathbf{x}_t,t)\) 的前提下达到了经验PFODE在给定 \(s_\phi(\mathbf{x}_t,t)\) 的一致性函数最优解,有\(\mathcal{L}_{CD}^\infty(\theta,\theta^-;\phi)=0\)\(\nabla_\theta\mathcal{L}_{CD}^\infty(\theta,\theta^-;\phi)=\mathbf{0}\) 。然而,反过来并不成立。\(\mathcal{L}_{CD}^\infty(\theta,\theta^-;\phi)=0\)\(\mathcal{L}_{CD}^\infty(\theta,\theta;\phi)=0\) 的区别在于\(\mathcal{L}_{CD}^\infty(\theta,\theta;\phi)=0\) 才是真正的损失函数。

一致性训练

一致性训练的定理就是将\(s_\phi(\mathbf{x}_t,t)\) 用它的平替进行替换,也即:

\[s_\phi(\mathbf{x}_t,t)\approx-\frac{\mathbf{x}_t-\mathbf{x}}{t^2}\]

对于\(\theta^-=\text{stopgrad}(\theta)\) 的情况,有如下定理:

定理4.令 $t_n=\tau\left(\frac{n-1}{N-1}\right)$,其中\(n\in[1,N]\) 且为整数,\(\tau(\cdot)\) 是一个严格单调函数,满足 $\tau(0)=\varepsilon$,\(\tau(1)=T\)。假设 \(\tau\) 在[0,1]区间连续可微,\(d\) 是一个三阶连续可导函数且三阶导数有界,\(f_\theta\) 二阶连续可导且一阶和二阶导数均有界。又假设权重函数\(\lambda(\cdot)\) 是有界的,满足\(\mathbb{E}[\|\nabla\log p_{t_n}(\mathbf{x}_{t_n})\|_2^2]<\infty, \sup_{\mathbf{x},t\in[\varepsilon,T]}\|s_\phi(\mathbf{x},t)\|_2<\infty\) , \(\phi\) 表示满足\(s_\phi(\mathbf{x},t)\equiv\nabla\log p_t(\mathbf{x})\) 的扩散模型参数。如果在一致性训练中使用Euler ODE求解器,且\(\theta ^- =stopgrad( \theta )\),则有:
\[\lim_{N\to\infty}(N-1)\nabla_\theta\mathcal{L}_{CD}^N(\theta,\theta^-;\phi)=\lim_{N\to\infty}(N-1)\nabla_\theta\mathcal{L}_{CT}^N(\theta,\theta^-)=\nabla\theta\mathcal{L}_{CT}^\infty(\theta,\theta^-)\]

\(\mathcal{L}_{CT}^\infty(\theta,\theta^-)\) 的形式为:
\[\mathcal{L}_{CT}^{\infty}(\theta,\theta^{-}):=\mathbb{E}\left[\frac{\lambda(t)}{(\tau^{-1})'(t)}f_{\theta}(\mathbf{x}_{t},t)^{T}\mathbf{H}(f_{\theta^{-}}(\mathbf{x}_{t},t))\left(\frac{\partial f_{\theta^{-}}(\mathbf{x}_{t},t)}{\partial t}+\frac{\partial f_{\theta^{-}}(\mathbf{x}_{t},t)}{\partial\mathbf{x}_{t}}\cdot\frac{\mathbf{x}_{t}-\mathbf{x}}{t}\right)\right]\]

其中,\(\mathbf{x} \sim p_{\mathrm{data}}\),\(u\sim\mathcal{U}[0,1],t=\tau(u),\mathbf{x}_t\sim\mathcal{N}(\mathbf{x},t^2\mathbf{I})\)

证明过程也是四个字”泰勒救我”,建议看论文逐步跟着推导,与前面推导过程类似,在此不再赘述。对于定理4作者也给了评论,咱们一个一个来看:

评论4.1. \(\mathcal{L}_{CT}^\infty(\theta,\theta^-)\) 不依赖于已有扩散模型 \(\phi\), 因此可以在没有任何预训练扩散模型的基础上进行优化。

评论4.1就是直观的解释,一致性训练不依赖于任何已有扩散模型如\(s_\phi(\mathbf{x}_t,t)\),在连续一致性模型仍然成立。

评论4.2.\(d(\mathbf{x},\mathbf{y})=\|\mathbf{x}-\mathbf{y}\|_2^2\)时,连续时间一致性训练目标为:
\[ \mathcal{L}_{CT}^{\infty}(\theta,\theta^{-})=2\mathbb{E}\left[\frac{\lambda(t)}{(\tau^{-1})'(t)}f_{\theta}(\mathbf{x}_{t},t)^{T}\left(\frac{\partial f_{\theta^{-}}(\mathbf{x}_{t},t)}{\partial t}+\frac{\partial f_{\theta^{-}}(\mathbf{x}_{t},t)}{\partial\mathbf{x}_{t}}\cdot\frac{\mathbf{x}_{t}-\mathbf{x}}{t}\right)\right]\]

评论4.2也属于"日常评论"了,推导过程也十分简单,不在此进行推导。

评论4.3.与 \(\mathcal{L}_{CD}^\infty(\theta,\theta^-;\phi)\) 类似,\(\mathcal{L}_{CT}^\infty(\theta,\theta^-)\) 也是一个"伪目标”,而并不是真正的训练目标,只有它的梯度具有意义,也即在训练过程中不能通过监控 \(\mathcal{L}_{CT}^\infty(\theta,\theta^-)\) 的值决定模型训练情况。然而,可以通过对 \(\mathcal{L}_{CT}^\infty(\theta,\theta^-)\)使用梯度下降法以数据驱动的方式训练得到一致性模型 \(f_\theta(\mathbf{x}_t,t)\) 。此外,如果 \(f_\theta(\mathbf{x}_t,t)\) 就是PFODE所对应的真实一致性函数,则\(\mathcal{L}_{CT}^\infty(\theta,\theta^-)=0\)\(\nabla_\theta\mathcal{L}_{CT}^\infty(\theta,\theta^-)=\mathbf{0}_{\circ}\)

到此,连续一致性模型的主要内容就讲完了。作者最后对连续一致性模型的性能进行了测试,总体而言效果不如离散一致性模型,导致效果不好的原因作者认为是连续一致性模型训练目标的方差较大。宋博士并没有就此止步,随后又提出了sCM方法对连续一致性模型进行改进,使连续一致性模型达到了SOTA的效果。

Refrence

从DDPM到Consistency Models(笔记)

一致性模型Consistency Models

【AI知识分享】一致性模型基本原理解析,110分钟硬核干货分享,这110分钟你绝对花的值!