LinearAttention 概述

Mar 27, 2025
2 views
NLP

概述

众所周知,尽管基于Attention机制的Transformer类模型有着良好的并行性能,但它的空间和时间复杂度都是 \(\mathcal{O}(n^2)\) 级别的,\(n\) 是序列长度,所以当 \(n\) 比较大时Transformer模型的计算量难以承受。近来,也有不少工作致力于降低Transformer模型的计算量,比如模型剪枝、量化、蒸馏等精简技术,又或者修改Attention结构,使得其复杂度能降低到 \(\mathcal{O}(n\log n)\) 甚至 \(\mathcal{O}(n)\)

改变这一复杂度的思路主要有两种:

快速预览

其实linear attention的思想很简单,就是把

\[ \mathbf{O} = \operatorname{softmax}(\mathbf{Q}\mathbf{K}^\top) \mathbf{V} \]

的softmax去掉,变成了

\[ \mathbf{O} = (\mathbf{Q}\mathbf{K}^\top) \mathbf{V} \]

然后借助矩阵乘法结合律得到

\[ \mathbf{O} = \mathbf{Q}(\mathbf{K}^\top \mathbf{V}) \]

在双向注意力里,比方说古代的bert时期,以及计算机视觉领域中,这样就已经足够了,大家开开心心地在线性时间内算两个很大的矩阵乘法,甚至都不需要写kernel就能很高效

但是在autoregressive modeling中,我们需要有causal mask。 训练和推理的形式分别是:

\[ \begin{align*} \mathbf{O} &= \operatorname{softmax}(\mathbf{Q}\mathbf{K}^\top \odot \mathbf{M}) \mathbf{V} &&\in \mathbb{R}^{L\times d} \\ \mathbf{o_t} &= \sum_{j=1}^t \frac{\exp(\mathbf{q}_t^\top \mathbf{k}j)}{\sum_{l=1}^t\exp(\mathbf{q}^\top_t \mathbf{k}_l)}\mathbf{v}_j && \in \mathbb{R}^d \end{align*} \]

同样地,把去掉softmax之后,我们可以得到

\[ \begin{aligned}\mathbf{O} &= (\mathbf{Q}\mathbf{K}^\top \odot \mathbf{M}) \mathbf{V} && \in \mathbb{R}^{L \times d} \\ \mathbf{o}_t &= \sum_{j=1}^t (\mathbf{q}_t^\top \mathbf{k}_j) \mathbf{v}_j && \in \mathbb{R}^d \end{aligned} \]

由于这个 \(\mathbf{M}\) 的存在,我们不能直接利用矩阵乘法的结合律得到上面先算KV 矩阵乘法的线性形式(因为矩阵乘法跟矩阵点乘是不可以交换的)

Transformers are RNNs

自回归linear attention的开山之作

Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

主要的idea是去掉标准Attention中的Softmax,就可以使得Attention的复杂度退化为理想的\(\mathcal{O}(n)\)级别(Linear Attention)。相比于其他类似的改进结构的工作,这种修改能在把复杂度降到 \(\mathcal{O}(n)\)的同时,依然保留所有的“token-token“的注意力,同时还能保留用于做自回归生成的可能性。

其对应的attention计算可以写为去掉softmax并带有kernel的形式:

\[ \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)\boldsymbol{v}_j}{\sum\limits_{j=1}^n \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)}\end{equation} \]

利用矩阵乘法的结合律,可以将上式简化为:

\[ \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\phi(q_i)^T \sum_{j=1}^{N} \phi(k_j) v_j^T}{\phi(q_i)^T \sum_{j=1}^{N} \phi(k_j)}\end{equation} \]

详细可以参考:

Performer

从Performer出发思考了线性Attention的一些问题,包括关于线性Attention的激活函数选择,以及线性Attention的瓶颈所在(低秩性、稀疏性),总的结论是,线性Attention的最佳激活函数应当是指数函数,而有效的Attention机制应当具备更高的秩和更大的稀疏性。

详情可以参考:

The Devil in Linear Transformer

这个工作主要的贡献是可以去掉整体的normalization的分母项。并证明了分母带来数值问题,所以在最近的linear attention中几乎全部去掉了,取而代之的是加上output normalization

\[ \begin{equation}O_{norm} = \text{XNorm}(Q(K^TV))\end{equation} \]

详情可以参考:

另外,Fine-Tuning Pre-trained Transformers into Decaying Fast Weights发现QK的activation啥也不设就good enough,后续的RetNet/GLA也不用激活函数,所以这两个term都省掉了。

FLASH

在 Transformers are Rnns 的实现中 linear attention 还存在一个问题:循环训练并行度太差了。此外,linear attention的recurrent update全部都是element-wise的操作(外积,点乘,...),根本没有一丁点矩阵乘法的影子,而矩阵乘法在GPU上非常高效(相同数量的FLOPs,在A100上用tensor cores算半精度矩阵乘法的效率是其他操作的16倍,所以现代的算法都是怎么矩阵乘法怎么来,这也是为什么注意力机制最先被提出来,然后直接席卷deep learning,因为它训的快呀。)

在parallel形式中,我们不需要算任何hidden state,只通过Q K V来得到output,但是需要 $\mathcal{O}(L^2) $ 复杂度。在Recurrent形式中,我们需要算每个time step的hidden state,但是只需要 $\mathcal{O}(L) $ 的复杂度。

那么存不存在一个介于两者之间算法,能够减少recurrent state的数量从而减少循环的次数,同时复杂度依然是线性的呢?

答案就是linear attention的第三种形式:chunkwise parallel form

最早应该是在Transformer Quality in Linear Time 提出的,现在所有的线性注意力训练都是基于chunkwise parallel form。当chunk size \(C=1\) 的时候,它和recurrent form等价,当 \(C=L\) 的时候,它跟parallel form等价。这是一个exact的算法,而不是一个approximate的算法,所以chunk size不影响output的大小

详见:

Lightning Attention

在全局线性注意力中,每个位置的 token 都可以看到整个序列,所以我们可以先计算所有位置的\(\psi(K)^T V\) ,然后再用 \(\phi(Q)\) 与之做点积。

\[ \text{Global Attention}(Q, K, V) = \phi(Q) (\psi(K)^T V) = [q_1, q_2, q_3, q_4] \begin{bmatrix} k_1^T v_1 + k_2^T v_2 + k_3^T v_3 + k_4^T v_4 \end{bmatrix} \]

我们可以直接使用矩阵乘法:

  1. 计算 \(\psi(K)^T V\)\(\begin{bmatrix} k_1^T v_1 + k_2^T v_2 + k_3^T v_3 + k_4^T v_4 \end{bmatrix}\) 这是一个简单的向量乘法。
  2. 计算\(\phi(Q) (\psi(K)^T V)\) : \([q_1, q_2, q_3, q_4] \begin{bmatrix} k_1^T v_1 + k_2^T v_2 + k_3^T v_3 + k_4^T v_4 \end{bmatrix}\) 这是一个向量乘以一个标量的运算。
    在这个过程中,我们可以使用矩阵乘法高效并行地完成计算,复杂度为\(O(n)\)

但是,在LLM推理中,我们通常需要「因果性」,每个位置的 token 只能看到它之前的 tokens,所以我们需要为每个位置单独计算注意力,并且要考虑到每个位置可见的 tokens 的数量是不同的。比如:

  • 位置 1:\(x_1\) 只能看到自己, \(text{Output}_1 = q_1 (k_1^T v_1)\)
  • 位置 2:\(x_2\) 可以看到 \(x_1\) ** 和 \(x_2\) ,**\(text{Output}_2 = q_2 (k_1^T v_1 + k_2^T v_2)\)
  • 位置 3:\(x_3\) 可以看到 \(x_1\)\(x_2\) \(x_3\)\(text{Output}_3 = q_3 (k_1^T v_1 + k_2^T v_2 + k_3^T v_3)\)
  • 位置 4:\(x_4\) 可以看到 \(x_1\)\(x_2\)\(x_3\) \(x_4\)\(text{Output}_4 = q_4 (k_1^T v_1 + k_2^T v_2 + k_3^T v_3 + k_4^T v_4)\)
    这样的cumsum操作无法被高效地表达为矩阵乘法,因此虽然计算复杂度下来了,但实际运算的效率并不高。

Lightning Attention如何克服传统线性注意力的问题

传统线性注意力虽然降低了复杂度,但在实际实现中面临一个关键问题:cumsum操作。这种操作会导致严重的内存瓶颈和计算效率下降,特别是在处理长序列时。

Lightning Attention 利用了分块技术,有效地规避了cumsum操作带来的问题。从算法1可以看出其实现细节:

image

  1. IO感知的分块策略
  2. 注意力计算的双重分解
  3. 累积矩阵的巧妙应用
  4. 并行优化与硬件友好
    这种优化很像FlashAttention的思路,它们的基本机制相近,都是「在块/小批量维度上将专用计算搬到高速缓存中进行,从而避免大矩阵在显存之间频繁交互」。所以,从工程实现角度,也可以把 Lightning Attention 看作「针对线性注意力的 FlashAttention 思路移植与优化」。

Reference

知乎:Lightning Attention 是如何克服传统线性注意力机制需要累加求和的缺陷的?