线性注意力与混合架构
标准注意力的 $O(n^2)$ 复杂度是长序列的根本瓶颈。一条研究主线试图把它降到线性 $O(n)$——线性注意力、RWKV、RetNet,以及把它们和 Transformer 缝在一起的混合架构(Jamba 等)。这是 2024–2025 长上下文与高效推理的前沿。标准注意力见 Attention 与变体,状态空间模型见 Mamba/SSM。
一、问题:注意力为什么是 O(n²)?
标准自注意力要算「每个 token 对每个 token」的相似度,得到 $n \times n$ 的注意力矩阵:
- 计算量和显存都随序列长度平方增长;
- 推理时 KV Cache 随序列线性增长,但每步注意力仍要扫全部历史。
序列从 1K 到 100K,平方项让成本爆炸——这是长上下文的核心障碍。
二、线性注意力的核心思想
标准注意力:$\text{softmax}(QK^T)V$,必须先算 $QK^T$($n\times n$)。线性注意力去掉 softmax、用一个核函数 $\phi$ 近似,利用矩阵乘法结合律改变计算顺序:
标准: (Q · Kᵀ) · V 先算 n×n 矩阵 → O(n²)
线性: Q · (Kᵀ · V) 先算 d×d 矩阵 → O(n)关键:$K^TV$ 是一个 $d\times d$ 的「状态」,与序列长度无关。于是注意力可以写成 RNN 式的递推——维护一个固定大小的状态,每来一个 token 更新它:
$$S_t = S_{t-1} + \phi(k_t)v_t^T, \quad o_t = \phi(q_t)S_t$$
这带来一个巨大好处:推理时不再需要随长度增长的 KV Cache,只需一个固定大小的状态——长上下文的显存和速度问题迎刃而解。
三、代表方法
| 方法 | 一句话 | 特点 |
|---|---|---|
| Linear Transformer | 用核函数去 softmax,注意力变 RNN | 线性注意力的奠基 |
| RWKV | 把类 RNN 的递推做成可并行训练的「线性注意力」 | 训练像 Transformer、推理像 RNN |
| RetNet | 保留多尺度衰减的「retention」机制 | 训练并行 / 推理递推 / 分块三种模式 |
| Mamba(SSM) | 选择性状态空间模型,本质也是线性递推 | 见 状态空间模型 |
| GLA、DeltaNet 等 | 带门控/增量更新的线性注意力新变体 | 提升表达力 |
共同范式:训练时并行(像 Transformer 一样快),推理时递推(像 RNN 一样省,常数显存)——鱼和熊掌兼得是这条线的卖点。
四、代价:线性注意力为什么没完全取代 Transformer?
把整段历史压进一个固定大小的状态,必然有信息损失:
- 精确回忆弱:标准注意力能精确「翻看」任意历史 token(关联回忆、抄写长 ID);固定状态记不全细节,在「大海捞针」「精确检索」类任务上弱于 full attention。
- 表达力权衡:去掉 softmax 的非线性,理论表达力受限,需要门控等机制补偿。
这就是为什么纯线性架构在通用基准上长期略逊于同规模 Transformer——省了复杂度,但丢了精确记忆。
五、混合架构(当前主流答案)
既然各有所长,就混着用:大部分层用高效的线性/SSM 层(省、快、长),少数层穿插标准注意力(保留精确回忆能力)。
| 模型 | 做法 |
|---|---|
| Jamba | Mamba 层 + Transformer 注意力层 + MoE 交错堆叠 |
| Mamba-2 / 混合 SSM | SSM 为主,少量注意力层 |
| 各家长上下文模型 | 滑动窗口注意力 + 少量全局注意力层 |
混合架构是「线性层管长程效率,注意力层管精确回忆」的工程折中,在长上下文 + 低成本 + 不丢能力之间取得平衡,被越来越多新模型采用。
与滑动窗口注意力的关系:滑动窗口(每个 token 只看附近 W 个)也是把 $O(n^2)$ 降到 $O(nW)$ 的简单办法,常和少量全局注意力层组合(见 长上下文专题)。
高频追问
Q:线性注意力为什么能做到 O(n)? 去掉 softmax 后,利用矩阵乘法结合律把计算顺序从 $(QK^T)V$ 改成 $Q(K^TV)$;$K^TV$ 是与序列长度无关的 $d\times d$ 状态矩阵,于是整体复杂度从平方降到线性,并能写成「维护固定状态、逐 token 递推」的 RNN 形式。
Q:线性注意力推理时为什么省显存? 它把历史压缩进一个固定大小的状态($d\times d$),不需要随序列增长的 KV Cache。无论上下文多长,状态大小不变——这正是它在长上下文推理上的最大优势。
Q:既然线性注意力又快又省,为什么没取代 Transformer? 因为「固定状态」装不下全部历史细节,精确回忆能力弱:标准注意力能精确翻看任意历史 token,线性注意力在「大海捞针」、精确检索、长程关联任务上明显逊色。省复杂度的代价是丢精确记忆,所以通用能力长期略逊。
Q:RWKV / RetNet 和 Mamba 是一回事吗? 都属于「训练并行、推理递推、常数显存」的线性序列模型大家族,数学上高度相关(都可看作某种线性递推/状态空间)。区别在具体的状态更新规则、衰减/门控设计。Mamba 的卖点是「选择性」(状态更新依赖输入),见 状态空间模型。
Q:混合架构为什么是当前主流方向? 纯注意力贵($O(n^2)$),纯线性丢精确回忆。混合架构让大多数层用线性/SSM(省、快、长上下文),少数层用标准注意力(保住精确检索能力),在成本、长度、能力之间取得最佳折中。Jamba 等是代表。
Q:滑动窗口注意力算线性注意力吗? 不算同一类,但目标相同(降复杂度)。滑动窗口仍是 softmax 注意力,只是限制每个 token 只看邻近 W 个,复杂度降到 $O(nW)$;线性注意力是改变数学形式去掉 softmax。两者常与全局注意力层组合用于长上下文。