本站内容由 AI 生成,可能存在错误。如发现问题,欢迎到 GitHub Issues 反馈。

Flash Attention 分块原理

Flash Attention 分块原理

更新于 2026-04-23

简介:为什么标准 Attention 的内存是瓶颈

在前面的文章中,我们学习了 Scaled Dot-Product Attention 的计算过程:

Attention(Q,K,V)=softmax ⁣(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V

标准实现需要三步:

  1. 计算 S=QKTRN×NS = QK^T \in \mathbb{R}^{N \times N} — 存到 HBM
  2. 计算 P=softmax(S)RN×NP = \text{softmax}(S) \in \mathbb{R}^{N \times N} — 存到 HBM
  3. 计算 O=PVRN×dO = PV \in \mathbb{R}^{N \times d} — 存到 HBM

问题在于中间矩阵 SSPP,它们的大小都是 N×NN \times N。当序列长度 NN 较大时(例如 N=4096N = 4096),这两个矩阵就需要 40962×264MB4096^2 \times 2 \approx 64\text{MB}(以 fp16 计算)的内存。更关键的是,这些矩阵需要在 GPU 的 HBM(高带宽内存)中反复读写,而 HBM 的带宽远低于 GPU 的计算速度。

Flash Attention(Dao et al., 2022)提出了一个关键洞察:通过分块计算和 Online Softmax,可以完全避免存储 N×NN \times N 的中间矩阵,将内存从 O(N2)O(N^2) 降到 O(N)O(N),同时大幅减少 HBM 访问次数。

Flash Attention 内存层次问题GPU 内存层次 & Attention 访问模式GPUSRAM~20MB | 19 TB/s12× 差距HBM40-80GB | 1.5 TB/sN×N 矩阵标准 Attention写 N×N 到 HBM,再读回来IO: O(N²)Flash Attention分块留在 SRAM,极少访问 HBMIO: O(N²d²/M)SRAM: 19 TB/s vs HBM: 1.5 TB/s — 12 倍带宽差距

GPU 内存层次:SRAM vs HBM

要理解 Flash Attention 的设计动机,必须先了解 GPU 的内存层次结构。

两级存储

存储层级类型A100 容量带宽特点
SRAM(片上缓存)寄存器 + 共享内存~20MB(每 SM 192KB)~19 TB/s极快,但容量极小
HBM(高带宽内存)显存40-80 GB~1.5-2.0 TB/s容量大,但带宽有限

关键数据:SRAM 的带宽是 HBM 的约 10 倍,但容量只有 HBM 的约 1/2000。

GPU 内存层次与数据搬运对比

SRAM~20MB · 19 TB/sHBM80GB · 2 TB/sbandwidth bottleneck

Standard Attention(6 HBM transfers)

Step 1HBM → SRAMRead Q, K
Step 1SRAM → HBMWrite S = QKᵀ
Step 2HBM → SRAMRead S
Step 2SRAM → HBMWrite P = softmax(S)
Step 3HBM → SRAMRead P, V
Step 3SRAM → HBMWrite O = PV

Flash Attention(2 HBM transfers)

LoadHBM → SRAMRead Q, K, V blocks
Compute⟳ SRAMQKᵀ → scale → mask → softmax → ×V (all in SRAM)
WriteSRAM → HBMWrite final O only

标准 Attention 需要 6 次 HBM 传输(3 读 + 3 写),Flash Attention 只需 2 次(1 读 + 1 写)

标准 Attention 的内存访问模式

标准实现的问题不在于计算量(FLOPs),而在于内存访问量(IO):

Step 1: 从 HBM 读 Q, K → 计算 S = QK^T → 写 S 到 HBM     (读 2Nd, 写 N²)
Step 2: 从 HBM 读 S     → 计算 P = softmax(S) → 写 P 到 HBM (读 N², 写 N²)
Step 3: 从 HBM 读 P, V  → 计算 O = PV → 写 O 到 HBM        (读 N²+Nd, 写 Nd)

总 HBM 访问量:Θ(Nd+N2)\Theta(Nd + N^2)。当 NdN \gg d 时,N2N^2 项主导。

Flash Attention 的目标:通过分块计算,将所有中间结果保留在 SRAM 中,把 HBM 访问量降到 Θ(N2d2M1)\Theta(N^2 d^2 M^{-1}),其中 MM 是 SRAM 大小。

分块策略:Tiling Q, K, V

Flash Attention 的第一个技术是分块(Tiling):将 Q、K、V 分成大小合适的块,使得每个块能完全放入 SRAM。

块大小的选择

设 SRAM 大小为 MM,头维度为 dd

Bc=M4d,Br=min ⁣(M4d,d)B_c = \left\lceil \frac{M}{4d} \right\rceil, \quad B_r = \min\!\left(\left\lceil \frac{M}{4d} \right\rceil, d\right)

这样一个 Br×dB_r \times d 的 Q 块、一个 Bc×dB_c \times d 的 K 块、一个 Bc×dB_c \times d 的 V 块,加上一个 Br×BcB_r \times B_c 的局部分数矩阵,都能放进 SRAM。

分块大小计算器

Bc = ⌈M/(4d)⌉
400
Br = min(Bc, d)
64
Q blocks (Tr)
8
K/V blocks (Tc)
2
Q 矩阵 (512×64)
64×64
64×64
64×64
64×64
64×64
64×64
64×64
64×64
K, V 矩阵 (512×64)
400×64
400×64

SRAM 越大 → 块越大 → 外循环次数越少 → HBM 访问越少(当前共 16 次块计算)

双层循环结构

Flash Attention 使用嵌套循环:

外层循环 (j = 1 to T_c):     // 遍历 K, V 的块
  将 K_j, V_j 从 HBM 加载到 SRAM
  内层循环 (i = 1 to T_r):   // 遍历 Q 的块
    将 Q_i, O_i, l_i, m_i 从 HBM 加载到 SRAM
    在 SRAM 中计算局部注意力
    更新 O_i, l_i, m_i
    写回 HBM

其中 Tc=N/BcT_c = \lceil N / B_c \rceil 是 K/V 的块数,Tr=N/BrT_r = \lceil N / B_r \rceil 是 Q 的块数。

关键:N×NN \times N 的注意力矩阵从未被完整存储。 每次只计算一个 Br×BcB_r \times B_c 的小块,用完即丢。

Online Softmax:核心创新的详细推导

分块计算的最大挑战是 softmax 需要看到整行的数据才能归一化。如果只看到部分列,怎么计算正确的 softmax?

这就是 Online Softmax 要解决的问题。

标准 Softmax 回顾

对一行向量 xRBx \in \mathbb{R}^B,数值稳定的 softmax 是:

m(x)=maxixi,f(x)=[ex1m(x)exBm(x)],(x)=if(x)i,softmax(x)=f(x)(x)m(x) = \max_i x_i, \quad f(x) = \begin{bmatrix} e^{x_1 - m(x)} & \cdots & e^{x_B - m(x)} \end{bmatrix}, \quad \ell(x) = \sum_i f(x)_i, \quad \text{softmax}(x) = \frac{f(x)}{\ell(x)}

其中 m(x)m(x) 是最大值(保证数值稳定),f(x)f(x) 是 shifted 指数向量,(x)\ell(x) 是归一化常数。

拆分为两个块

假设向量 xx 被拆成两部分 x=[x(1),x(2)]x = [x^{(1)}, x^{(2)}],其中 x(1),x(2)RBx^{(1)}, x^{(2)} \in \mathbb{R}^B。我们要证明可以从两部分的局部统计量推导出全局 softmax。

全局最大值可以从局部最大值获得:

m(x)=max ⁣(m(x(1)),m(x(2)))m(x) = \max\!\big(m(x^{(1)}), m(x^{(2)})\big)

全局 shifted 指数向量:

f(x)=[em(x(1))m(x)f(x(1))em(x(2))m(x)f(x(2))]f(x) = \begin{bmatrix} e^{m(x^{(1)}) - m(x)} f(x^{(1)}) & e^{m(x^{(2)}) - m(x)} f(x^{(2)}) \end{bmatrix}

全局归一化常数:

(x)=em(x(1))m(x)(x(1))+em(x(2))m(x)(x(2))\ell(x) = e^{m(x^{(1)}) - m(x)} \ell(x^{(1)}) + e^{m(x^{(2)}) - m(x)} \ell(x^{(2)})

关键洞察: 指数修正因子 em(x(1))m(x)e^{m(x^{(1)}) - m(x)} 就是用来补偿局部 max 和全局 max 之差的。如果新块的最大值更大(m(x(2))>m(x(1))m(x^{(2)}) > m(x^{(1)})),之前所有的 eximolde^{x_i - m_{\text{old}}} 都需要乘以 emoldmnewe^{m_{\text{old}} - m_{\text{new}}} 来修正。

递推算法

这个分解可以递推应用于任意数量的块。设处理完第 jj 个块后的统计量为 mjm_jj\ell_jOjO_j。当第 j+1j+1 个块到来时:

Step 1:计算局部分数

S~=QiKj+1T/d\tilde{S} = Q_i K_{j+1}^T / \sqrt{d}

Step 2:计算局部统计量

m~=rowmax(S~),P~=exp(S~m~),~=rowsum(P~)\tilde{m} = \text{rowmax}(\tilde{S}), \quad \tilde{P} = \exp(\tilde{S} - \tilde{m}), \quad \tilde{\ell} = \text{rowsum}(\tilde{P})

Step 3:更新全局统计量

mnew=max(mj,m~)m^{\text{new}} = \max(m_j, \tilde{m}) new=emjmnewj+em~mnew~\ell^{\text{new}} = e^{m_j - m^{\text{new}}} \cdot \ell_j + e^{\tilde{m} - m^{\text{new}}} \cdot \tilde{\ell}

Step 4:修正并更新输出

Onew=diag(new)1 ⁣(diag(j)emjmnewOj+em~mnewP~Vj+1)O^{\text{new}} = \text{diag}(\ell^{\text{new}})^{-1} \!\left( \text{diag}(\ell_j) \cdot e^{m_j - m^{\text{new}}} \cdot O_j + e^{\tilde{m} - m^{\text{new}}} \cdot \tilde{P} \cdot V_{j+1} \right)

这个公式的含义:

  • diag(j)Oj\text{diag}(\ell_j) \cdot O_j:将之前的输出”反归一化”回未除以 \ell 的状态
  • emjmnewe^{m_j - m^{\text{new}}}:修正因子,补偿旧 max 和新 max 之差
  • em~mnewP~Vj+1e^{\tilde{m} - m^{\text{new}}} \cdot \tilde{P} \cdot V_{j+1}:新块的贡献(也要修正到新 max)
  • diag(new)1\text{diag}(\ell^{\text{new}})^{-1}:用新的归一化常数重新归一化

为什么是精确的?

Online Softmax 不是近似,而是数学上精确等价于标准 softmax。整个推导基于简单的代数恒等式:

eximoldemnewmold=eximnew\frac{e^{x_i - m_{\text{old}}}}{e^{m_{\text{new}} - m_{\text{old}}}} = e^{x_i - m_{\text{new}}}

无论数据被分成多少块、以什么顺序处理,最终结果完全相同。

Block 1: 初始化
B1:[2.1, 3.2]
B2:[4.1, 1.5]
B3:[2.8, 3]

s₁ = [2.1, 3.2] → m₁ = max(2.1, 3.2) = 3.2 → exp(s₁ - m₁) = [0.3329, 1.0000] → l₁ = 1.3329

m = 3.2l = 1.3329O = [0.7251, 0.1499]

无需存储完整的 N×N 矩阵,只需维护 m, l, O 三个累积量

交互演示:Flash Attention 分块计算

下面用一个 N=4,d=3,B=2N=4, d=3, B=2 的小例子,逐步演示 Flash Attention 如何处理 Q 的第一个块(t1,t2t_1, t_2),依次与两个 K/V 块交互,并通过 Online Softmax 修正得到精确结果。

Q, K, V 矩阵与分块

标准 Attention 需要存储完整的 N×N 注意力矩阵到 HBM,内存为 O(N²)Flash Attention 的核心思想:将 Q、K、V 分成小块,在 SRAM 中逐块计算,永远不存储完整的 N×N 矩阵

Q ∈ ℝ^(4×3)
d₁
d₂
d₃
t₁
0.05
0.11
0.42
t₂
0.03
0.89
0.59
t₃
0.63
0.06
0.25
t₄
-0.56
0.56
0.76
(4, 3)
K ∈ ℝ^(4×3)
d₁
d₂
d₃
t₁
0.99
-0.13
0.51
t₂
-0.54
-0.85
0.13
t₃
0.17
-0.34
0.28
t₄
0.42
-0.63
-0.28
(4, 3)
V ∈ ℝ^(4×3)
d₁
d₂
d₃
t₁
-0.07
0.10
0.13
t₂
0.89
-0.59
0.14
t₃
-0.29
0.79
0.78
t₄
-0.13
0.65
0.68
(4, 3)
分块:块大小 Br = Bc = 2。高亮行 = 第一个块(t₁, t₂),非高亮行 = 第二个块(t₃, t₄)。我们将以 Q 的第一个块为例,展示如何逐步处理两个 K/V 块。

内存从 O(N2)O(N^2)O(N)O(N) 的推导

标准 Attention 的内存

标准实现需要存储中间矩阵 SSPP

内存=NdQ+NdK+NdV+N2S+N2P+NdO=Θ(Nd+N2)\text{内存} = \underbrace{Nd}_Q + \underbrace{Nd}_K + \underbrace{Nd}_V + \underbrace{N^2}_S + \underbrace{N^2}_P + \underbrace{Nd}_O = \Theta(Nd + N^2)

NdN \gg d 时,O(N2)O(N^2) 项主导。

Flash Attention 的内存

Flash Attention 只需要存储输入、输出和辅助统计量:

内存=NdQ+NdK+NdV+NdO+N+Nm=Θ(Nd)\text{内存} = \underbrace{Nd}_Q + \underbrace{Nd}_K + \underbrace{Nd}_V + \underbrace{Nd}_O + \underbrace{N}_{\ell} + \underbrace{N}_{m} = \Theta(Nd)

没有任何 N2N^2 项!局部的 Br×BcB_r \times B_c 分数矩阵只在 SRAM 中临时存在,不占 HBM。

Theorem 1(Dao et al., 2022):Flash Attention 算法返回 O=softmax(QKT)VO = \text{softmax}(QK^T)V,使用 O(N2d)O(N^2 d) FLOPs,仅需 O(N)O(N) 的额外内存。

IO 复杂度分析:为什么更快

Flash Attention 不仅省内存,更省时间,因为 GPU 上 Attention 的瓶颈不是计算而是内存访问。

标准 Attention 的 IO 复杂度

HBM 访问=Θ(Nd+N2)\text{HBM 访问} = \Theta(Nd + N^2)

Flash Attention 的 IO 复杂度

Theorem 2(Dao et al., 2022):设 NN 为序列长度,dd 为头维度,MM 为 SRAM 大小(dMNdd \leq M \leq Nd)。标准 Attention 需要 Θ(Nd+N2)\Theta(Nd + N^2) 次 HBM 访问;Flash Attention 需要 Θ(N2d2M1)\Theta(N^2 d^2 M^{-1}) 次。

直觉理解:

  • 外层循环遍历 Tc=N/BcT_c = N/B_c 个 K/V 块,每次加载 Θ(Bcd)=Θ(M)\Theta(B_c d) = \Theta(M) 数据
  • 内层循环遍历 Tr=N/BrT_r = N/B_r 个 Q 块,每次加载和写回 Θ(Brd)\Theta(B_r d) 数据
  • 总访问量:Tc×(M+Tr×Brd)=NBc×NBr×Brd=N2dBcT_c \times (M + T_r \times B_r d) = \frac{N}{B_c} \times \frac{N}{B_r} \times B_r d = \frac{N^2 d}{B_c}
  • 因为 Bc=Θ(M/d)B_c = \Theta(M/d),所以 N2dBc=Θ(N2d2/M)\frac{N^2 d}{B_c} = \Theta(N^2 d^2 / M)

对于典型参数(d=64-128d = 64\text{-}128M100KBM \approx 100\text{KB}),d2d^2 远小于 MM,因此 N2d2/MN2N^2 d^2 / M \ll N^2。实验中 Flash Attention 比标准实现快 2-4 倍

IO 复杂度对比:Standard vs Flash v1 vs v2

2565121K2K4K8K16K32K64K序列长度 NHBM 访问量 (log scale)
Standard Θ(Nd+N²)
Flash v1 Θ(N²d²/M)
Flash v2 Θ(N²d/M)

长序列下标准方案 IO 爆炸 vs Flash Attention 的亚二次增长

下界

Proposition 3(Dao et al., 2022):不存在精确 Attention 算法能在所有 M[d,Nd]M \in [d, Nd] 范围内达到 o(N2d2M1)o(N^2 d^2 M^{-1}) 的 HBM 访问量。

这意味着 Flash Attention 在 IO 复杂度意义上是渐近最优的

Flash Attention v1 vs v2

Flash Attention v1 vs v2Flash Attention v1 vs v2v1外层: 遍历 K/V 块内层: 遍历 Q 块多线程块写同一 Q 输出 → 需要同步SM 利用率:SM0SM1SM2SM3SM4SM5部分 SM 空闲A100: 25-40% 利用率v2外层: 遍历 Q 块内层: 遍历 K/V 块每线程块独占 Q 输出 → 无需同步SM 利用率:SM0SM1SM2SM3SM4SM5全部 SM 活跃A100: 50-73% 利用率v2: ~2× 加速 — 更好的并行 + 减少共享内存读写

2023 年,Tri Dao 发布了 Flash Attention v2,在 v1 的基础上进一步优化了 GPU 并行度。

对比项Flash Attention v1Flash Attention v2
外层循环遍历 K/V 块遍历 Q 块
内层循环遍历 Q 块遍历 K/V 块
线程块间并行不同头 & batch 并行额外在 Q 块维度并行
非 matmul FLOPs较多减少,更好利用 Tensor Core
warp 间通信通过共享内存减少 warp 间通信
A100 利用率理论峰值的 25-40%理论峰值的 50-73%
相对加速基准v1 的 ~2x

v2 的关键改进

1. 循环顺序交换

v1 的外层循环遍历 K/V 块,内层遍历 Q 块。v2 反过来:外层遍历 Q 块,内层遍历 K/V 块。这样每个线程块只负责一个 Q 块的输出,减少了同步开销,并且允许在 Q 块维度上并行分配到不同的线程块(streaming multiprocessor)。

2. 减少非 matmul FLOPs

GPU 的 Tensor Core 对矩阵乘法有极高的吞吐量,但 Online Softmax 中的 rescaling、max、sum 等操作是非 matmul FLOPs。v2 通过算法调整减少了这些操作的比例。

3. 更好的 warp 内工作分配

v2 优化了 warp 之间的任务划分,减少了通过共享内存同步的次数,进一步提升了并行效率。

总结

Flash Attention 通过三个核心技术解决了标准 Attention 的内存和速度瓶颈:

技术解决的问题效果
Tiling(分块)N×NN \times N 矩阵不放入 HBM内存 O(N2)O(N)O(N^2) \to O(N)
Online Softmax分块计算中正确归一化精确等价,零近似误差
IO 感知设计减少 HBM 访问次数速度 2-4x 提升

核心公式速查:

mnew=max(mold,m~),new=emoldmnewold+em~mnew~m^{\text{new}} = \max(m^{\text{old}}, \tilde{m}), \quad \ell^{\text{new}} = e^{m^{\text{old}} - m^{\text{new}}} \ell^{\text{old}} + e^{\tilde{m} - m^{\text{new}}} \tilde{\ell} Onew=diag(new)1 ⁣(emoldmnewdiag(old)Oold+em~mnewP~V)O^{\text{new}} = \text{diag}(\ell^{\text{new}})^{-1}\!\left(e^{m^{\text{old}} - m^{\text{new}}} \text{diag}(\ell^{\text{old}}) O^{\text{old}} + e^{\tilde{m} - m^{\text{new}}} \tilde{P} V\right)

Flash Attention 已成为现代大模型推理和训练的标准组件。从 PyTorch 2.0 开始,torch.nn.functional.scaled_dot_product_attention 默认使用 Flash Attention 后端。理解其分块原理,是深入理解 LLM 系统优化的重要基础。