Skip to content

FlashAttention 深入

「FlashAttention 为什么能又快又省显存?它是近似算法吗?」是工程/算法岗的硬核进阶题。本文讲清它的核心思想、online softmax、反向重计算,以及 v1/v2/v3 的演进。基础概念见 Attention 与变体

标准注意力的问题:不是算力,是访存

标准注意力计算 O = softmax(QKᵀ/√d)·V,朴素实现会显式地在显存里生成 N×N 的中间矩阵 S = QKᵀP = softmax(S)

瓶颈在于 GPU 的内存层级:

        容量      带宽        延迟
SRAM    ~20MB    ~19 TB/s    极低   ← 计算单元旁的高速缓存
HBM     ~80GB    ~3 TB/s     高     ← 显存

朴素实现要把 N×N 矩阵反复读写 HBM(慢速显存),对长序列来说,注意力是 memory-bound(访存受限) 的——时间几乎都花在搬数据,而非算数。

关键洞察:减少 HBM 读写次数,比减少计算量更能提速。FlashAttention 正是一个 IO 感知(IO-aware) 的算法。

核心思想:分块 + 在 SRAM 里融合计算

FlashAttention 把 Q、K、V 切成小块(tiling),把每一块加载进 SRAM,在 SRAM 内一次性完成 QKᵀ、softmax、乘 V 的全过程(算子融合),绝不把完整的 N×N 矩阵写回 HBM

难点在于:softmax 需要对一整行做归一化(要知道全行的最大值和求和),但分块时我们一次只看到一部分。解法是 online softmax(在线 softmax)

Online Softmax(增量式 softmax)

逐块处理时,维护「当前的最大值 m」和「当前的指数和 l」两个统计量,每来一个新块就:

  1. 更新全局最大值 m_new = max(m_old, 块内最大值)
  2. 用新旧最大值的差,对已累积的结果做**重缩放(rescale)**校正;
  3. 累加本块的贡献。

这样无需看到全行,就能数值稳定地增量算出正确的 softmax 加权和——结果与标准注意力完全一致

所以 FlashAttention 是精确算法,不是近似。它只是换了一种计算顺序,避免了大矩阵落盘。

复杂度收益

  • 显存:从 O(N²)(存中间矩阵)降到 O(N)(只存输出和统计量)。
  • HBM 访问:从 O(N²) 降到约 O(N²/M)(M 为 SRAM 大小),大幅减少访存。
  • 速度:长序列上数倍提速,且支持更长上下文。

反向传播:用重计算换显存

反向传播需要 softmax 矩阵 P,但前向没有保存它(正是为了省显存)。FlashAttention 在反向时重新计算 P(利用前向保存的少量统计量 m、l),用计算换显存——这与 梯度检查点 是同一思路。

v1 → v2 → v3 演进

版本关键改进
v1提出 tiling + online softmax + 重计算,奠定 IO 感知范式
v2优化并行与工作划分:在序列长度维度上并行、减少非矩阵乘(softmax 等)的开销、更好的 warp 间任务分配,GPU 利用率显著提升(约 2×)
v3针对 Hopper(H100):利用异步指令、warp 专门化(warp specialization)、TMA 异步搬数、支持 FP8,进一步压榨硬件

一句话演进:v1 解决「省显存+精确」,v2 解决「并行效率」,v3 吃满「新硬件特性」。

高频追问

Q:FlashAttention 是近似注意力吗? 不是。它通过 online softmax 增量地、数值稳定地算出与标准注意力完全相同的结果。它优化的是计算顺序和访存,不牺牲精度。这区别于稀疏/线性注意力那类近似方法。

Q:它到底快在哪?算力没变啊。 关键不是减少计算量,而是减少对慢速 HBM 的读写。标准实现要把 N×N 中间矩阵反复读写显存(访存受限);FlashAttention 分块后在高速 SRAM 内融合计算,几乎不落盘,访存大幅下降,于是变快。

Q:online softmax 怎么保证数值稳定? softmax 要减去最大值防止指数溢出。分块时维护「当前最大值」,每来新块就更新最大值并对已累积结果重缩放校正,等价于对全行减最大值,因此既正确又稳定。

Q:前向不存注意力矩阵,反向怎么办? 反向时用前向保存的统计量(m、l)和 Q、K、V 重新计算注意力矩阵,用计算换显存,和梯度检查点思路一致。

Q:v2 相比 v1 主要快在哪? v1 主要在 batch 和 head 维并行;v2 增加了序列长度维的并行、减少了 softmax 等非矩阵乘运算的指令开销、改进了 warp 之间的工作划分,使 GPU 占用率更高,速度约翻倍。

Q:FlashAttention 和 PagedAttention 是一回事吗? 不是。FlashAttention 优化单次注意力计算的访存(训练/推理都用);PagedAttention 优化推理时 KV Cache 的显存管理(分页减碎片),二者解决不同问题,可同时使用。

基于 MIT 许可发布