Skip to content

线性注意力与混合架构

标准注意力的 $O(n^2)$ 复杂度是长序列的根本瓶颈。一条研究主线试图把它降到线性 $O(n)$——线性注意力、RWKV、RetNet,以及把它们和 Transformer 缝在一起的混合架构(Jamba 等)。这是 2024–2025 长上下文与高效推理的前沿。标准注意力见 Attention 与变体,状态空间模型见 Mamba/SSM

一、问题:注意力为什么是 O(n²)?

标准自注意力要算「每个 token 对每个 token」的相似度,得到 $n \times n$ 的注意力矩阵:

  • 计算量和显存都随序列长度平方增长;
  • 推理时 KV Cache 随序列线性增长,但每步注意力仍要扫全部历史。

序列从 1K 到 100K,平方项让成本爆炸——这是长上下文的核心障碍。

二、线性注意力的核心思想

标准注意力:$\text{softmax}(QK^T)V$,必须先算 $QK^T$($n\times n$)。线性注意力去掉 softmax、用一个核函数 $\phi$ 近似,利用矩阵乘法结合律改变计算顺序:

标准: (Q · Kᵀ) · V   先算 n×n 矩阵 → O(n²)
线性: Q · (Kᵀ · V)   先算 d×d 矩阵 → O(n)

关键:$K^TV$ 是一个 $d\times d$ 的「状态」,与序列长度无关。于是注意力可以写成 RNN 式的递推——维护一个固定大小的状态,每来一个 token 更新它:

$$S_t = S_{t-1} + \phi(k_t)v_t^T, \quad o_t = \phi(q_t)S_t$$

这带来一个巨大好处:推理时不再需要随长度增长的 KV Cache,只需一个固定大小的状态——长上下文的显存和速度问题迎刃而解。

三、代表方法

方法一句话特点
Linear Transformer用核函数去 softmax,注意力变 RNN线性注意力的奠基
RWKV把类 RNN 的递推做成可并行训练的「线性注意力」训练像 Transformer、推理像 RNN
RetNet保留多尺度衰减的「retention」机制训练并行 / 推理递推 / 分块三种模式
Mamba(SSM)选择性状态空间模型,本质也是线性递推状态空间模型
GLA、DeltaNet 等带门控/增量更新的线性注意力新变体提升表达力

共同范式:训练时并行(像 Transformer 一样快),推理时递推(像 RNN 一样省,常数显存)——鱼和熊掌兼得是这条线的卖点。

四、代价:线性注意力为什么没完全取代 Transformer?

把整段历史压进一个固定大小的状态,必然有信息损失:

  • 精确回忆弱:标准注意力能精确「翻看」任意历史 token(关联回忆、抄写长 ID);固定状态记不全细节,在「大海捞针」「精确检索」类任务上弱于 full attention。
  • 表达力权衡:去掉 softmax 的非线性,理论表达力受限,需要门控等机制补偿。

这就是为什么纯线性架构在通用基准上长期略逊于同规模 Transformer——省了复杂度,但丢了精确记忆

五、混合架构(当前主流答案)

既然各有所长,就混着用:大部分层用高效的线性/SSM 层(省、快、长),少数层穿插标准注意力(保留精确回忆能力)。

模型做法
JambaMamba 层 + Transformer 注意力层 + MoE 交错堆叠
Mamba-2 / 混合 SSMSSM 为主,少量注意力层
各家长上下文模型滑动窗口注意力 + 少量全局注意力层

混合架构是「线性层管长程效率,注意力层管精确回忆」的工程折中,在长上下文 + 低成本 + 不丢能力之间取得平衡,被越来越多新模型采用。

与滑动窗口注意力的关系:滑动窗口(每个 token 只看附近 W 个)也是把 $O(n^2)$ 降到 $O(nW)$ 的简单办法,常和少量全局注意力层组合(见 长上下文专题)。

高频追问

Q:线性注意力为什么能做到 O(n)? 去掉 softmax 后,利用矩阵乘法结合律把计算顺序从 $(QK^T)V$ 改成 $Q(K^TV)$;$K^TV$ 是与序列长度无关的 $d\times d$ 状态矩阵,于是整体复杂度从平方降到线性,并能写成「维护固定状态、逐 token 递推」的 RNN 形式。

Q:线性注意力推理时为什么省显存? 它把历史压缩进一个固定大小的状态($d\times d$),不需要随序列增长的 KV Cache。无论上下文多长,状态大小不变——这正是它在长上下文推理上的最大优势。

Q:既然线性注意力又快又省,为什么没取代 Transformer? 因为「固定状态」装不下全部历史细节,精确回忆能力弱:标准注意力能精确翻看任意历史 token,线性注意力在「大海捞针」、精确检索、长程关联任务上明显逊色。省复杂度的代价是丢精确记忆,所以通用能力长期略逊。

Q:RWKV / RetNet 和 Mamba 是一回事吗? 都属于「训练并行、推理递推、常数显存」的线性序列模型大家族,数学上高度相关(都可看作某种线性递推/状态空间)。区别在具体的状态更新规则、衰减/门控设计。Mamba 的卖点是「选择性」(状态更新依赖输入),见 状态空间模型

Q:混合架构为什么是当前主流方向? 纯注意力贵($O(n^2)$),纯线性丢精确回忆。混合架构让大多数层用线性/SSM(省、快、长上下文),少数层用标准注意力(保住精确检索能力),在成本、长度、能力之间取得最佳折中。Jamba 等是代表。

Q:滑动窗口注意力算线性注意力吗? 不算同一类,但目标相同(降复杂度)。滑动窗口仍是 softmax 注意力,只是限制每个 token 只看邻近 W 个,复杂度降到 $O(nW)$;线性注意力是改变数学形式去掉 softmax。两者常与全局注意力层组合用于长上下文。

基于 MIT 许可发布