Reading

Consistency Models

Diffusion Models from SDE

连续扩散模型 (Continuous Diffusion Models) 将传统的离散时间扩散过程扩展到连续时间域,可以被视为一个随机过程,使用随机微分方程(SDE)来描述。其前向过程可以写成如下形式:

其中, 可以看成偏移系数, 可以看成是扩散系数, 是标准布朗运动。这个SDE
描述了数据在连续时间域内如何被噪声逐渐破坏。
这个随机过程的逆向过程存在(更准确的描述:下面的逆向时间SDE具有与正向过程SDE相同的联合分布)为

前面我们得到了扩散过程的逆向过程可以用一个SDE描述(逆向随机过程),事实上,**存在一个确定性过程 (用ODE描述)也是它的逆向过程 (更准确的描述:这个ODE过程的在任意时刻 ** 的状态的边缘分布 与 SDE过程的相同

其中, 叫做score function , 可以通过score match进行训练,也兼容普通的扩散模型噪声预测模型。

这个ODE称为概率流ODE(probability flow ODE,PF ODE,它沿着概率流的轨迹(也就是ODE的解函数) 建立了从噪声分布中的点 与到数据分布中的点 的映射,也就是说PF ODE建立了高斯分布样本与数据分布样本的映射。

因此,采样时,把 代入PF-ODE, 仅需采样一次, 然后可以确定性的解出一条从的轨迹,这样的到的的分布就是data分布。可以用ODE数值求解器来求解出 ,例如欧拉法,Runge- Kutta法之类的数值方法。

Consistency Models

有了PF ODE,我们已经能做到采样一个高斯噪声,然后通过求解ODE,映射到数据。但求解ODE 仍然是一个迭代过程,要算很多次 ,效率很低。我们能不能直接拟合出来这个ODE表示的映射 (也就是ODE的解函数),然后一步到位?Consistency models (一致性模型)就是相当于直接学了个ODE的解:任给某个轨迹上的点 ,一步inference得到 !
首先我们回顾扩散过程SDE。在Consistency Models这篇论文里,为了最后的PF ODE形式的简单,他们采用了扩散项系数 ,漂流项 为0,

这样 ,其中是卷积操作,。用 score model 进行估计。在上述设置下, 带入到1式3式, 可以写作:

式5 为 empirical PF ODE。

image

对于式4的 SDE, 可以对两边进行的积分:

其中第二行利用的是随机微分方程SDE和基于 Ito 积分的性质,也可以理解为这个等式右侧是对正态分布进行积分,相当于无穷个带有系数 的正态分布的相加,那么根据正态分布的线性性质,可以知道是无穷个个正态分布相加,均值还是 0,方差为 最后一行相当于一个加噪的过程,由于方差 一般比较大,所以可以近似 ,同样的有任意时间步 的加噪过程

* *中采样初始噪声,然后使用 Euler/Heun 求解器求解式5, 可以得到一条采样轨迹 。为了数值稳定性,通常在 处停止采样得到 , 可以近似认为满足分布。当 不是从 0 开始而是从 开始,对应的加噪过程也有:

由于, 所以先验分布仍然可以近似为 。并且,类似的有:

定义

一致性函数

在给定概率流常微分方程(Probability Flow Ordinary Differential Equation, PFODE)的情况下,一致性函数可表示为下面的形
式:

一致性模型的学习目标就是在给定PFODE上,得到一个神经网络使得任意相邻两个时刻的输出尽可能相等,进而逼近一致性函数。

如下图所示:

image

神经网络 的形式为:

其中 是两个关于时间的可微函数,保证,满足一致性函数的要求。

def denoise(self, model, x_t, sigmas, **model_kwargs):
    import torch.distributed as dist

    if not self.distillation:
        c_skip, c_out, c_in = [
            append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)
        ]
    else:
        c_skip, c_out, c_in = [
            append_dims(x, x_t.ndim)
            for x in self.get_scalings_for_boundary_condition(sigmas)
        ]

    rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
    model_output = model(c_in * x_t, rescaled_t, **model_kwargs)
    denoised = c_out * model_output + c_skip * x_t
    return model_output, denoised

可以发现,在实现时引入多了一个 ,来限制输入图像,有点像LSTM的输入门。 为人为设定,和时间相关,当增加时,图像噪声水平增加,离采样终点越远远, 的值会下降, 的值会上升,更少的输入信号 被保留,更多依靠模型去进行预测。相反,当 减少时,图像噪声水平下降,离采样终点越趋近, 的值会上升, 的值会下降,更多的输入信号 被保留,更少依靠模型去进行预测。至于 ,可以参考EDM算法

采样

一致性模型 训练好后,可以先从先验分布中采样 ,然后通过*
*进行一步采样。
除此之外,还可以通过交替去噪和加噪 (在不同的时间点上进行一步加噪和一步去噪),提高样本
质量,即:

image

其中一步加噪使用的就是式6

训练Loss

一致性模型采用了让PF-ODE相邻两个时间点模型输出值差距最小化的方式实现利用神经网络逼近一致性函数的目标,损失函数可以写为:

其中, 表示时间点设置的数目,是通过一种ODE求解器获得的上一个时刻 的图像,这里用 表示第 个时间点对应的时间, 且为整数。 是一个超参数,理论这个值越大,在PF-ODE上的路径点数目就会越多,两点之间越靠近, 的求解会更准,模型精度越好。实际上当足够大以后,这个数值对模型性能影响已经不敏感。作者在这里并没有直接使用上面的损失函数形式,而是将上式后一项的 的权重 换成了模型的指数滑动平均值 (Exponential Moving Average, EMA) 。根据EMA的性质,给定衰减系数,可得:

此时,损失函数可写为:

该式便是论文中损失函数的一种简化形式。采用EMA相当于用一个多轮加权平均的结果充当最终的模型,每一步不同的模型权重 对其影响较小,因此又被称作"目标模型"。采用EMA的原因是可以提升训练过程的稳定性,提升一致性模型的效果。有了损失函数的形式了,就可以考虑如何训练模型了。论文中给定了两种方法, 一种是从已有模型切入,也即一致性蒸馏 (Consistency Distillation, CD),一种是从零开始训练一个新的一致性模型,也即一致性训练(Consistency Training, CT).

一致性蒸馏 (Consistency Distillation, CD)

一致性模型的训练可以从已有模型切入,比如采用score-based模型进行蒸馏。假设现在已经有一个模型, 很自然地,通过模型预测出一个score的值,再通过欧拉法采样一步,就可以获得在PF-ODE路径上相邻点的位置。根据PF-ODE形式,很显然有:

代入10式损失函数中可得CD的损失函数形式为:

论文中给出了CD过程的算法步骤如下,结合损失函数来看不难理解。首先从数据中采样样本,均匀采样时间点,根据时间加噪声。加完噪声后,采用ODE数值求解器获得上一个时间点的图像,紧接着计算损失函数并进行梯度反传。

image

其中, 表示模型学习的"司机”, 可以是score,也可以是速度场等ODE迭代必须元素。是权重函数,作者设置恒等于 1, 是衡量两个输入距离的指标,作者考虑了 L1、L2 和 LPIPS。需要注意的是,每一步需要同步更新EMA权重的值。

一致性训练(Consistency Training, CT)

一致性模型难道一定要背靠一个现成的扩散模型吗?也可以从零开始训练一个一致性模型。可以发现,没有现成扩散模型的最大问题是 没了,没办法进行ODE的数值求解了。作者就给定了一个score-based模型预测结果的平替,用这个平替让模型进行学习,进而替代。这个平替需要至少满足以下两个条件:

  1. 的无偏估计
  2. 需要在训练过程中作为金标准,因此需要可以计算获得
    基于上述两个条件,作者给出了平替的形式为:

其中,,可以发现。

论文中推导的前几步可能稍微难理解一点,在这里给大家全部补上。回忆一下概率论的基本概念,根据边缘概率密度与联合概率密度的关系,可得如下的等式:

接下来就是公式推导,公式推导的目的就是从score的定义出发,探索到底有没有一个关于score的无偏估计量。总体而言,宋博士在这里的推导过程相当清晰,几乎没有跳步,仅在开始两步稍微有点难度,在这里给大家一一分解。根据score的定义,关于求导,实际上是一个复合函数求导!也即先对log求导,再对求导!心中牢记复合函数求导,则推导过程并不难理解,如下所示:

通过推导可以发现,这个平替确实是score的无偏估计。

既然无偏估计出现了,对于原始损失函数,只需要把原来依靠模型预测值预测的直接写为平替形式
也即

所以,CT的损失函数形式为:

论文给出了基于CT损失函数的模型训练算法步骤,可以发现相比CD来说,主要不同点为相邻时刻(也即上一时刻)
的图像直接通过原始图像加噪而非ODE迭代获得 (与公式(16)等价)。

image

如何证明这个估计是正确有效的?假如我们已经有了一个ground truth score function
,可以证明这个估计的loss与给定ground truth score function的
loss之间的误差趋于0,即。作者通过泰勒展开证明
了这个结论。

在实验中,步数N随训练逐渐增加,可以加快收敛。

image

效果上,目前diffusion > CD > diffusion distillation > CT,(GAN的效果不太好比较,在有些数据集上,经过这么多年充分调教的GAN性能甚至再次反超diffusion,有的数据集上则不如CT)。作为单步方法的CD和CT打不过diffusion,其实可以理解,一方面这one-one mapping本身非常复杂,还是比较难拟合的;另一方面,consistency models训练方法是约束一条ODE轨迹上的相邻点输出相同,模型拟合不好的话,每一步的误差都会累积下来。

不过consistency model才刚刚提出不久,在训练方法上还有很多待研究的技术。而作为一种单步生成方法,相对于diffusion在性能上有着显著优势。

连续一致性模型

损失函数推导:以一致性蒸馏中 为例

上述的一致性模型是基于离散时间序列 ,最典型的体现就是得用一个 ODE 求解器从 * 获得 *。ODE 求解器一定存在由于离散时间导致的截断误差,也就从理论上无法达到最优。离散到连续的转化主要体现在 ODE 求解器的迭代公式,在离散情况下, 可以近似为 , 在连续情况下我们有一个关于 的连续时间函数 ,所以 。 其中,,也即

  • 离散情形:
  • 连续情形:
    上述看似简单的改变会导致一致性模型学习目标(损失函数)的改变,一致性模型论文分了很多种情况进行推导讨论。让我们先考虑最简单的情况,先认为 ,也即不使用 EMA 权重,考虑一致性蒸馏场景。令 **,**连续形式的损失函数可以改写为:

损失函数(16)是以一种“如连续”的状态,你说它是连续的,它毕竟没有显式的离散两个时间点的差了;你说它不是连续的,它毕竟还存在一个好像不那么和谐的 ,而且 还是存在的。公式(16)虽然形式有了,但很难用代码实现。现在看来似乎没有什么办法了,毫无思路。毫无思路的时候,记住一个口诀:泰勒救我! 光有泰勒显然还不够,我们回想一致性模型的学习目标是希望相邻两个时间点的输出尽可能的接近,最简单的想法就是,先无脑求一下下面这个式子:

很显然,直接求没法求,这个时候泰勒展开就能出场了:

公式(18)暂时现放着,后面有用。还有一个式子似乎可以用泰勒展开,那就是距离度量函数,也即:

从始至终,影响我们无法进行实现连续一致性模型的"罪魁祸首"就是这个, 既然看它不爽,那我
们对公式(19)的第二个输入 处进行泰勒展开。需要注意,这里实际上是一个"一元”函数的泰勒展开,有:

这里引入了一个新的矩阵,它的具体形式如下:

这个矩阵实际上就是黑塞 (Hessian) 矩阵,记录了一个标量输出的函数与输入向量所有元素可能的二阶偏导数。注意公式(20),有一些项是可以化简的:

  1. 对于距离度量,在同一位置的两点距离为0,也即,所以有:
  2. 对于距离度量, 当两点距离位置相同时的梯度为0,也即。所以有: