简介:为什么标准 Attention 的内存是瓶颈
在前面的文章中,我们学习了 Scaled Dot-Product Attention 的计算过程:
Attention(Q,K,V)=softmax(dkQKT)V
标准实现需要三步:
- 计算 S=QKT∈RN×N — 存到 HBM
- 计算 P=softmax(S)∈RN×N — 存到 HBM
- 计算 O=PV∈RN×d — 存到 HBM
问题在于中间矩阵 S 和 P,它们的大小都是 N×N。当序列长度 N 较大时(例如 N=4096),这两个矩阵就需要 40962×2≈64MB(以 fp16 计算)的内存。更关键的是,这些矩阵需要在 GPU 的 HBM(高带宽内存)中反复读写,而 HBM 的带宽远低于 GPU 的计算速度。
Flash Attention(Dao et al., 2022)提出了一个关键洞察:通过分块计算和 Online Softmax,可以完全避免存储 N×N 的中间矩阵,将内存从 O(N2) 降到 O(N),同时大幅减少 HBM 访问次数。
GPU 内存层次:SRAM vs HBM
要理解 Flash Attention 的设计动机,必须先了解 GPU 的内存层次结构。
两级存储
| 存储层级 | 类型 | A100 容量 | 带宽 | 特点 |
|---|
| SRAM(片上缓存) | 寄存器 + 共享内存 | ~20MB(每 SM 192KB) | ~19 TB/s | 极快,但容量极小 |
| HBM(高带宽内存) | 显存 | 40-80 GB | ~1.5-2.0 TB/s | 容量大,但带宽有限 |
关键数据:SRAM 的带宽是 HBM 的约 10 倍,但容量只有 HBM 的约 1/2000。
GPU 内存层次与数据搬运对比
Standard Attention(6 HBM transfers)
Step 1HBM → SRAMRead Q, K
Step 1SRAM → HBMWrite S = QKᵀ
Step 2HBM → SRAMRead S
Step 2SRAM → HBMWrite P = softmax(S)
Step 3HBM → SRAMRead P, V
Step 3SRAM → HBMWrite O = PV
Flash Attention(2 HBM transfers)
LoadHBM → SRAMRead Q, K, V blocks
Compute⟳ SRAMQKᵀ → scale → mask → softmax → ×V (all in SRAM)
WriteSRAM → HBMWrite final O only
标准 Attention 需要 6 次 HBM 传输(3 读 + 3 写),Flash Attention 只需 2 次(1 读 + 1 写)
标准 Attention 的内存访问模式
标准实现的问题不在于计算量(FLOPs),而在于内存访问量(IO):
Step 1: 从 HBM 读 Q, K → 计算 S = QK^T → 写 S 到 HBM (读 2Nd, 写 N²)
Step 2: 从 HBM 读 S → 计算 P = softmax(S) → 写 P 到 HBM (读 N², 写 N²)
Step 3: 从 HBM 读 P, V → 计算 O = PV → 写 O 到 HBM (读 N²+Nd, 写 Nd)
总 HBM 访问量:Θ(Nd+N2)。当 N≫d 时,N2 项主导。
Flash Attention 的目标:通过分块计算,将所有中间结果保留在 SRAM 中,把 HBM 访问量降到 Θ(N2d2M−1),其中 M 是 SRAM 大小。
分块策略:Tiling Q, K, V
Flash Attention 的第一个技术是分块(Tiling):将 Q、K、V 分成大小合适的块,使得每个块能完全放入 SRAM。
块大小的选择
设 SRAM 大小为 M,头维度为 d:
Bc=⌈4dM⌉,Br=min(⌈4dM⌉,d)
这样一个 Br×d 的 Q 块、一个 Bc×d 的 K 块、一个 Bc×d 的 V 块,加上一个 Br×Bc 的局部分数矩阵,都能放进 SRAM。
分块大小计算器
Q 矩阵 (512×64)
64×64
64×64
64×64
64×64
64×64
64×64
64×64
64×64
SRAM 越大 → 块越大 → 外循环次数越少 → HBM 访问越少(当前共 16 次块计算)
双层循环结构
Flash Attention 使用嵌套循环:
外层循环 (j = 1 to T_c): // 遍历 K, V 的块
将 K_j, V_j 从 HBM 加载到 SRAM
内层循环 (i = 1 to T_r): // 遍历 Q 的块
将 Q_i, O_i, l_i, m_i 从 HBM 加载到 SRAM
在 SRAM 中计算局部注意力
更新 O_i, l_i, m_i
写回 HBM
其中 Tc=⌈N/Bc⌉ 是 K/V 的块数,Tr=⌈N/Br⌉ 是 Q 的块数。
关键:N×N 的注意力矩阵从未被完整存储。 每次只计算一个 Br×Bc 的小块,用完即丢。
Online Softmax:核心创新的详细推导
分块计算的最大挑战是 softmax 需要看到整行的数据才能归一化。如果只看到部分列,怎么计算正确的 softmax?
这就是 Online Softmax 要解决的问题。
标准 Softmax 回顾
对一行向量 x∈RB,数值稳定的 softmax 是:
m(x)=imaxxi,f(x)=[ex1−m(x)⋯exB−m(x)],ℓ(x)=i∑f(x)i,softmax(x)=ℓ(x)f(x)
其中 m(x) 是最大值(保证数值稳定),f(x) 是 shifted 指数向量,ℓ(x) 是归一化常数。
拆分为两个块
假设向量 x 被拆成两部分 x=[x(1),x(2)],其中 x(1),x(2)∈RB。我们要证明可以从两部分的局部统计量推导出全局 softmax。
全局最大值可以从局部最大值获得:
m(x)=max(m(x(1)),m(x(2)))
全局 shifted 指数向量:
f(x)=[em(x(1))−m(x)f(x(1))em(x(2))−m(x)f(x(2))]
全局归一化常数:
ℓ(x)=em(x(1))−m(x)ℓ(x(1))+em(x(2))−m(x)ℓ(x(2))
关键洞察: 指数修正因子 em(x(1))−m(x) 就是用来补偿局部 max 和全局 max 之差的。如果新块的最大值更大(m(x(2))>m(x(1))),之前所有的 exi−mold 都需要乘以 emold−mnew 来修正。
递推算法
这个分解可以递推应用于任意数量的块。设处理完第 j 个块后的统计量为 mj、ℓj、Oj。当第 j+1 个块到来时:
Step 1:计算局部分数
S~=QiKj+1T/d
Step 2:计算局部统计量
m~=rowmax(S~),P~=exp(S~−m~),ℓ~=rowsum(P~)
Step 3:更新全局统计量
mnew=max(mj,m~)
ℓnew=emj−mnew⋅ℓj+em~−mnew⋅ℓ~
Step 4:修正并更新输出
Onew=diag(ℓnew)−1(diag(ℓj)⋅emj−mnew⋅Oj+em~−mnew⋅P~⋅Vj+1)
这个公式的含义:
- diag(ℓj)⋅Oj:将之前的输出”反归一化”回未除以 ℓ 的状态
- emj−mnew:修正因子,补偿旧 max 和新 max 之差
- em~−mnew⋅P~⋅Vj+1:新块的贡献(也要修正到新 max)
- diag(ℓnew)−1:用新的归一化常数重新归一化
为什么是精确的?
Online Softmax 不是近似,而是数学上精确等价于标准 softmax。整个推导基于简单的代数恒等式:
emnew−moldexi−mold=exi−mnew
无论数据被分成多少块、以什么顺序处理,最终结果完全相同。
Block 1: 初始化
B1:[2.1, 3.2]
B2:[4.1, 1.5]
B3:[2.8, 3]
s₁ = [2.1, 3.2] → m₁ = max(2.1, 3.2) = 3.2 → exp(s₁ - m₁) = [0.3329, 1.0000] → l₁ = 1.3329
m = 3.2l = 1.3329O = [0.7251, 0.1499]
无需存储完整的 N×N 矩阵,只需维护 m, l, O 三个累积量
交互演示:Flash Attention 分块计算
下面用一个 N=4,d=3,B=2 的小例子,逐步演示 Flash Attention 如何处理 Q 的第一个块(t1,t2),依次与两个 K/V 块交互,并通过 Online Softmax 修正得到精确结果。
Q, K, V 矩阵与分块
标准 Attention 需要存储完整的 N×N 注意力矩阵到 HBM,内存为 O(N²)。Flash Attention 的核心思想:将 Q、K、V 分成小块,在 SRAM 中逐块计算,永远不存储完整的 N×N 矩阵。
分块:块大小 Br = Bc = 2。高亮行 = 第一个块(t₁, t₂),非高亮行 = 第二个块(t₃, t₄)。我们将以 Q 的第一个块为例,展示如何逐步处理两个 K/V 块。
内存从 O(N2) 到 O(N) 的推导
标准 Attention 的内存
标准实现需要存储中间矩阵 S 和 P:
内存=QNd+KNd+VNd+SN2+PN2+ONd=Θ(Nd+N2)
当 N≫d 时,O(N2) 项主导。
Flash Attention 的内存
Flash Attention 只需要存储输入、输出和辅助统计量:
内存=QNd+KNd+VNd+ONd+ℓN+mN=Θ(Nd)
没有任何 N2 项!局部的 Br×Bc 分数矩阵只在 SRAM 中临时存在,不占 HBM。
Theorem 1(Dao et al., 2022):Flash Attention 算法返回 O=softmax(QKT)V,使用 O(N2d) FLOPs,仅需 O(N) 的额外内存。
IO 复杂度分析:为什么更快
Flash Attention 不仅省内存,更省时间,因为 GPU 上 Attention 的瓶颈不是计算而是内存访问。
标准 Attention 的 IO 复杂度
HBM 访问=Θ(Nd+N2)
Flash Attention 的 IO 复杂度
Theorem 2(Dao et al., 2022):设 N 为序列长度,d 为头维度,M 为 SRAM 大小(d≤M≤Nd)。标准 Attention 需要 Θ(Nd+N2) 次 HBM 访问;Flash Attention 需要 Θ(N2d2M−1) 次。
直觉理解:
- 外层循环遍历 Tc=N/Bc 个 K/V 块,每次加载 Θ(Bcd)=Θ(M) 数据
- 内层循环遍历 Tr=N/Br 个 Q 块,每次加载和写回 Θ(Brd) 数据
- 总访问量:Tc×(M+Tr×Brd)=BcN×BrN×Brd=BcN2d
- 因为 Bc=Θ(M/d),所以 BcN2d=Θ(N2d2/M)
对于典型参数(d=64-128,M≈100KB),d2 远小于 M,因此 N2d2/M≪N2。实验中 Flash Attention 比标准实现快 2-4 倍。
IO 复杂度对比:Standard vs Flash v1 vs v2
长序列下标准方案 IO 爆炸 vs Flash Attention 的亚二次增长
下界
Proposition 3(Dao et al., 2022):不存在精确 Attention 算法能在所有 M∈[d,Nd] 范围内达到 o(N2d2M−1) 的 HBM 访问量。
这意味着 Flash Attention 在 IO 复杂度意义上是渐近最优的。
Flash Attention v1 vs v2
2023 年,Tri Dao 发布了 Flash Attention v2,在 v1 的基础上进一步优化了 GPU 并行度。
| 对比项 | Flash Attention v1 | Flash Attention v2 |
|---|
| 外层循环 | 遍历 K/V 块 | 遍历 Q 块 |
| 内层循环 | 遍历 Q 块 | 遍历 K/V 块 |
| 线程块间并行 | 不同头 & batch 并行 | 额外在 Q 块维度并行 |
| 非 matmul FLOPs | 较多 | 减少,更好利用 Tensor Core |
| warp 间通信 | 通过共享内存 | 减少 warp 间通信 |
| A100 利用率 | 理论峰值的 25-40% | 理论峰值的 50-73% |
| 相对加速 | 基准 | v1 的 ~2x |
v2 的关键改进
1. 循环顺序交换
v1 的外层循环遍历 K/V 块,内层遍历 Q 块。v2 反过来:外层遍历 Q 块,内层遍历 K/V 块。这样每个线程块只负责一个 Q 块的输出,减少了同步开销,并且允许在 Q 块维度上并行分配到不同的线程块(streaming multiprocessor)。
2. 减少非 matmul FLOPs
GPU 的 Tensor Core 对矩阵乘法有极高的吞吐量,但 Online Softmax 中的 rescaling、max、sum 等操作是非 matmul FLOPs。v2 通过算法调整减少了这些操作的比例。
3. 更好的 warp 内工作分配
v2 优化了 warp 之间的任务划分,减少了通过共享内存同步的次数,进一步提升了并行效率。
总结
Flash Attention 通过三个核心技术解决了标准 Attention 的内存和速度瓶颈:
| 技术 | 解决的问题 | 效果 |
|---|
| Tiling(分块) | N×N 矩阵不放入 HBM | 内存 O(N2)→O(N) |
| Online Softmax | 分块计算中正确归一化 | 精确等价,零近似误差 |
| IO 感知设计 | 减少 HBM 访问次数 | 速度 2-4x 提升 |
核心公式速查:
mnew=max(mold,m~),ℓnew=emold−mnewℓold+em~−mnewℓ~
Onew=diag(ℓnew)−1(emold−mnewdiag(ℓold)Oold+em~−mnewP~V)
Flash Attention 已成为现代大模型推理和训练的标准组件。从 PyTorch 2.0 开始,torch.nn.functional.scaled_dot_product_attention 默认使用 Flash Attention 后端。理解其分块原理,是深入理解 LLM 系统优化的重要基础。