本站内容由 AI 生成,可能存在错误。如发现问题,欢迎到 GitHub Issues 反馈。

算子融合(下):Cost Model 与融合实战

算子融合(下):Cost Model 与融合实战

更新于 2026-04-23

查看全景图用户代码全景图计算图捕获IR 设计优化 Pass算子融合9. Cost Model你在这里代码生成调度与执行硬件执行

简介

上一篇我们讨论了融合的 WHAT(五种融合类型)和 WHEN(合法性判定算法)。本文回答两个更关键的问题:

  1. WHETHER — 融合是否有益?不是所有合法的融合都值得做。
  2. HOW — 实践中如何实现高效融合?从 TorchInductor 的启发式到 FlashAttention 的算法改写。

核心观点:“能融合”不等于”该融合”。盲目融合可能导致 register pressure 增大、occupancy 下降、编译时间爆炸。成熟的编译器需要 cost model 来做出明智的融合决策。

Cost Model 设计

为什么不能”融合一切”?

融合并非越多越好。下图展示了融合程度与性能之间的典型关系——存在一个由 cost model 指导的最优区间:

融合程度与性能的权衡
融合程度与性能的权衡融合程度 (Fusion Degree)性能 (Performance)激进最优点融合不足过多 HBM 访问过度融合寄存器溢出 / 低占用率Cost Model 指导区间

假设我们有一个合法融合候选 A+B。从上篇我们知道它满足所有合法性条件。但融合后的 kernel 可能比两个独立 kernel 更慢。原因有三:

1. Register Pressure(寄存器压力)

每个 GPU 线程拥有有限的 register。NVIDIA GPU 的 SM(Streaming Multiprocessor)通常有 65536 个 32-bit register。如果一个 kernel 的每个线程需要 32 个 register,blockSize 为 256:

Registers per block=32×256=8192\text{Registers per block} = 32 \times 256 = 8192 Max blocks per SM=65536/8192=8\text{Max blocks per SM} = \lfloor 65536 / 8192 \rfloor = 8

融合后如果 register 需求从 32 增加到 48:

Registers per block=48×256=12288\text{Registers per block} = 48 \times 256 = 12288 Max blocks per SM=65536/12288=5\text{Max blocks per SM} = \lfloor 65536 / 12288 \rfloor = 5

blocks 从 8 降到 5,意味着 SM 上能同时运行的 warp 减少了 37.5%。

2. Occupancy(占用率)

Occupancy 是 GPU 性能的关键指标,定义为:

Occupancy=Active warps per SMMax warps per SM\text{Occupancy} = \frac{\text{Active warps per SM}}{\text{Max warps per SM}}

影响 occupancy 的三大约束:

  • Register 文件大小:如上所述,register 用量越多,能启动的 warp 越少
  • Shared memory 容量:V100 = 96 KB, A100 = 164 KB, H100 = 228 KB。如果一个 block 需要 32 KB shared memory,A100 上最多 5 个 block(164/32≈5)
  • Max threads per block:硬件限制(通常 1024)

低 occupancy 意味着 GPU 难以通过 warp 切换来隐藏内存延迟(latency hiding),导致实际吞吐量远低于峰值。

3. 编译时间

Triton 和 CUDA kernel 的编译时间大致与 kernel 大小的平方成正比。融合后的大 kernel 编译可能从毫秒级跳到秒级,对 JIT 编译(如 torch.compile)场景影响显著。

Roofline Model 的形式化

Cost model 的核心是 Roofline Model(Williams et al., 2009)。对于一个 kernel:

Texec=max(Tcompute,Tmemory)T_{\text{exec}} = \max(T_{\text{compute}}, T_{\text{memory}})

其中:

Tcompute=FLOPsPeak FLOPS×f(occupancy)T_{\text{compute}} = \frac{\text{FLOPs}}{\text{Peak FLOPS} \times f(\text{occupancy})} Tmemory=HBM bytesHBM bandwidthT_{\text{memory}} = \frac{\text{HBM bytes}}{\text{HBM bandwidth}}

f(occupancy)f(\text{occupancy}) 是一个单调递增函数,反映 occupancy 对实际吞吐量的影响。一般近似为 f(o)=max(0.25,o)f(o) = \max(0.25, o),即占用率低于 25% 时吞吐不再线性下降(因为指令级并行仍能提供一些利用率)。

下面的交互组件让你亲手调整硬件参数,观察不同融合决策的效果。

Cost Model 计算器选择 GPU:V100A100H100带宽: 2 TB/s算力: 312 TFLOPS | Shared Mem: 164 KBGELU + Dropout(融合有利)大 Reduction + 小 Pointwise(融合有害)MatMul + BiasAdd + ReLU(权衡)未融合GELUFLOPs (M):2HBM 读 (MB):4HBM 写 (MB):4Reg/Thread:16Shared Mem:Occupancy:100%估算时间:4.00 μs访存受限DropoutFLOPs (M):1HBM 读 (MB):4HBM 写 (MB):4Reg/Thread:12Shared Mem:Occupancy:100%估算时间:4.00 μs访存受限已融合GELU+DropoutFLOPs (M):3HBM 读 (MB):4HBM 写 (MB):4Reg/Thread:24Shared Mem:Occupancy:100%估算时间:4.00 μs访存受限估算时间未融合8.00 μs已融合4.00 μsOccupancy未融合100%已融合100%HBM 总量未融合16 MB已融合8 MB判定 融合有利100% 快两个 memory-bound pointwise op融合后消除中间 tensor 的 HBM 读写(4+4=8 MB),FLOPs 不变总是值得融合

TorchInductor 的 Cost Model 实战

PyTorch 2 的 TorchInductor(Ansel et al., 2024)采用启发式 cost model——不是精确建模,而是基于经验规则做决策。

核心融合启发式

Pointwise 融合(几乎总是做):

Inductor 对连续的 pointwise 操作(element-wise、broadcast)默认执行融合。理由简单:pointwise 操作不需要 shared memory,register 增量通常很小,而消除中间 tensor 的 HBM 读写几乎总是净正收益。

Reduction + Pointwise 融合(有条件):

对于 reduction(如 summax)后面紧跟 pointwise 操作,Inductor 会检查:

  • Reduction 维度是否足够小(不会导致 shared memory 溢出)
  • 融合后的 register 用量是否可控

如果 reduction 跨越了很大的维度(如 [batch, seq_len, hidden_dim] 上的 hidden_dim reduction),融合可能导致 shared memory 需求过高。

MatMul + Epilogue 融合(高价值):

GEMM 操作的 epilogue fusion 是最有价值的融合之一。在 GEMM tile 写回 HBM 之前,直接在寄存器中做 bias add、activation(ReLU/GELU)、dropout 等操作。cuBLAS 和 CUTLASS 都原生支持 epilogue fusion,Triton 则通过代码生成实现。

融合大小控制

# 控制单个融合 group 中最多包含多少个 node
torch._inductor.config.max_fusion_size = 64  # 默认值

# 控制 pointwise 融合的最大 node 数
torch._inductor.config.max_pointwise_cat_size = 8

调试融合决策

实际开发中,理解编译器的融合决策至关重要:

import torch

# 方法 1: 设置 trace 环境变量
# TORCHINDUCTOR_TRACE=1 python my_script.py

# 方法 2: 在代码中启用
torch._inductor.config.trace.enabled = True
torch._inductor.config.trace.graph_diagram = True  # 生成融合前后的图

# 方法 3: 查看生成的 Triton kernel
torch._inductor.config.debug = True
# 生成的 kernel 代码在 /tmp/torchinductor_<user>/ 目录下

查看 trace 输出可以发现类似这样的融合日志:

[FUSION] fused pointwise: relu + mul + add → fused_kernel_0 (3 nodes)
[FUSION] skipped: layernorm + large_epilogue (register pressure: 52 > threshold 48)

MLIR 级别的 Fusion

MLIR(Multi-Level Intermediate Representation)提供了比 Inductor 更原则性的融合方法。

Linalg Dialect 的 Fusion

MLIR 的 Linalg dialect 将张量运算表示为结构化操作(structured ops),天然支持 producer-consumer 融合分析。核心操作 linalg.fuse_into_containing_op 将 producer 的计算内联到 consumer 的循环体内。

Tile-and-Fuse:核心策略

MLIR 的融合策略遵循一个关键原则:先 tile,再 fuse

  1. Tile consumer:将 consumer 的计算拆分为适合目标内存层级的 tile(如 L1 cache 或 shared memory)
  2. Fuse producer into tile:将 producer 的计算融合到 consumer 的 tile 循环中
  3. 保证 working set 适配:由于 tile 大小由内存容量决定,融合后的 working set 天然不会溢出

这与 Inductor 的”先融合、再祈祷不溢出”形成鲜明对比。Tile-and-fuse 是正确性优先的方法——它从内存约束出发设计融合,而非事后检查。

Affine Fusion

对于完美仿射循环嵌套(perfect affine loop nests),MLIR 的 affine dialect 支持基于 polyhedral analysis 的循环融合。这种方法可以自动发现最优的循环融合顺序和 tile 大小,但仅适用于静态形状、仿射索引的场景。

对比总结

维度TorchInductorMLIR
方法启发式规则Tile-and-fuse + polyhedral
优点快速、实用、覆盖常见 pattern原则性强、正确性保证
缺点可能错过优化或做出次优决策编译开销较大、动态形状支持有限
适用JIT 编译(torch.compileAOT 编译(部署优化)

FlashAttention 深度剖析

FlashAttention(Dao et al., 2022)是 ML 系统领域最具影响力的优化之一。它不是通用融合——而是一个领域特定的算法改写

标准 Attention 的内存瓶颈

标准 Self-Attention 的计算流程:

Attention(Q,K,V)=softmax(QKTd)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V

其中 Q,K,VRN×dQ, K, V \in \mathbb{R}^{N \times d}NN 为序列长度,dd 为 head dimension)。

关键瓶颈在 S=QKTS = QK^T——一个 N×NN \times N 的矩阵。以 N=4096N = 4096, d=64d = 64, FP16 为例:

S=N2×2 bytes=40962×2=33,554,432 bytes32 MB|S| = N^2 \times 2 \text{ bytes} = 4096^2 \times 2 = 33{,}554{,}432 \text{ bytes} \approx 32 \text{ MB}

整个标准 attention 的 HBM 访问量:

  1. QQ: N×d×2=0.5N \times d \times 2 = 0.5 MB
  2. KK: N×d×2=0.5N \times d \times 2 = 0.5 MB
  3. SS: N2×2=32N^2 \times 2 = 32 MB → HBM
  4. SS: 3232 MB ← HBM(做 softmax)
  5. 写 softmax(SS): 3232 MB → HBM
  6. 读 softmax(SS): 3232 MB ← HBM(乘以 VV
  7. VV: 0.50.5 MB
  8. OO: 0.50.5 MB

总 HBM 访问:130\approx 130 MB,其中 128128 MB 都花在了 N×NN \times N 的分数矩阵上。

I/O 复杂度:Θ(Nd+N2)\Theta(Nd + N^2)。当 NdN \gg d 时(通常如此),N2N^2 项主导。

FlashAttention 的核心思路

FlashAttention 的关键洞察:不需要将整个 N×NN \times N 矩阵具象化(materialize)。下图展示了其分块策略——将注意力矩阵分成小 tile 在 SRAM 中计算,避免将完整 N×NN \times N 矩阵写入 HBM:

FlashAttention 分块策略
FlashAttention 分块策略注意力矩阵 S = QKᵀN × NBᵣ×Bᶜ从不写入 HBMQKᵀSRAM (快速、容量小)加载 Tile局部 QKᵀOnlineSoftmax累加 ×VHBM (慢速、容量大)Q [N,d]K [N,d]V [N,d]O [N,d]加载 tile写回 O关键洞察O(N²) 内存 → O(N) 内存

QQ, KK, VV 按行分块(tile):

Br=Bc=M4d2B_r = B_c = \left\lfloor \frac{M}{4d \cdot 2} \right\rfloor

其中 MM 是 SRAM(shared memory)大小(字节),dd 是 head dimension,×2\times 2 因为 FP16。

算法框架:

对每个 Q tile (Br 行):
    初始化 output 累加器 O_tile = 0, 运行最大值 m = -∞, 运行求和 l = 0
    对每个 K,V tile (Bc 行):
        从 HBM 加载 Q_tile, K_tile, V_tile 到 SRAM
        在 SRAM 中计算 S_tile = Q_tile × K_tile^T  (大小 Br × Bc,完全在 SRAM 内!)
        计算本地 softmax: m_new, l_new, P_tile
        用 online softmax rescaling 更新 O_tile
    将 O_tile 写回 HBM

每次迭代中,SRAM 中只存在一个 Br×BcB_r \times B_c 的小矩阵,而非完整的 N×NN \times N

Online Softmax:关键技巧

Softmax 本质上是全局操作——需要知道整行的最大值才能计算。FlashAttention 使用 online softmax(Milakov & Gimelshein, 2018)解决此问题:

对于向量 x=[x1,x2,,xN]x = [x_1, x_2, \ldots, x_N] 的 softmax σ(x)i=exijexj\sigma(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}},可以分块计算:

  1. 处理第 1 块:m1=max(x1:B)m_1 = \max(x_{1:B})l1=i=1Bexim1l_1 = \sum_{i=1}^{B} e^{x_i - m_1}
  2. 处理第 2 块:m2=max(m1,max(xB+1:2B))m_2 = \max(m_1, \max(x_{B+1:2B}))l2=l1em1m2+i=B+12Bexim2l_2 = l_1 \cdot e^{m_1 - m_2} + \sum_{i=B+1}^{2B} e^{x_i - m_2}
  3. 以此类推,每步都用新的 max 值 rescale 之前的累加结果

这样无需存储完整的 NN 维向量,就能正确计算 softmax。

I/O 复杂度分析

FlashAttention 的 HBM 访问量:

外层循环 N/Br\lceil N / B_r \rceil 次,每次:

  • 加载 QQ tile:Br×d×2B_r \times d \times 2 字节
  • 内层循环 N/Bc\lceil N / B_c \rceil 次,每次加载 KK tile + VV tile:2×Bc×d×22 \times B_c \times d \times 2 字节
  • 写回 OO tile:Br×d×2B_r \times d \times 2 字节

总 HBM 访问:

HBM=O(NBr(Brd+NBc2Bcd+Brd))=O(N2dBc+Nd)\text{HBM} = O\left(\frac{N}{B_r} \cdot \left(B_r d + \frac{N}{B_c} \cdot 2B_c d + B_r d\right)\right) = O\left(\frac{N^2 d}{B_c} + Nd\right)

由于 Bc=Θ(M/d)B_c = \Theta(M / d)

HBM=O(N2d2M+Nd)\text{HBM} = O\left(\frac{N^2 d^2}{M} + Nd\right)

M=Θ(Nd)M = \Theta(Nd) 时(SRAM 足够大),简化为 O(Nd)O(Nd)——与序列长度 NN 成线性关系

对比标准 attention 的 Θ(Nd+N2)\Theta(Nd + N^2),在长序列上 FlashAttention 的 I/O 效率是量级级别的提升。

FlashAttention-2:更少的非矩阵乘法 FLOPs

FlashAttention-2(Dao, 2023)的关键改进:

  1. 将 rescaling 移出内层循环:减少非 matmul FLOPs。在 GPU 上 Tensor Core 只加速 matmul;其他操作用 CUDA core,速度慢得多。FlashAttention-1 中每步都做 rescaling,产生大量 non-matmul FLOPs。FA-2 延迟 rescaling 到外层循环末尾。

  2. 更好的 warp 分区:FA-1 在 dd 维度上分 warp(每个 warp 处理 head 的一部分),需要 cross-warp reduction。FA-2 改为在 NN 维度上分 warp(每个 warp 处理不同的 K/V block),无需 cross-warp 通信。

结果:FA-2 达到理论 FLOPs 利用率的 50-73%(A100 上),相比 FA-1 的 25-40% 大幅提升。

FlashAttention-3:Hopper 架构特化

FlashAttention-3(Shah et al., 2024)为 NVIDIA Hopper(H100)架构深度优化:

  1. Warp-specialized pipeline:利用 Hopper 的异步执行特性,producer warps 负责 HBM→SRAM 数据搬运,consumer warps 负责计算,两者流水线并行。
  2. FP8 支持:Hopper 的 FP8 Tensor Core 提供双倍 FP16 的吞吐。FA-3 支持 FP8 计算并配合 incoherent processing(非相干处理)保证数值精度。
  3. Block quantization:将 softmax 中间结果量化为 FP8,减少 register 和 shared memory 占用。

为什么 FlashAttention 不是”算子融合”

值得强调:FlashAttention 不是通用编译器能自动发现的融合。它需要:

  • 理解 softmax 的数学性质(可以 online 计算)
  • 设计新算法(tiled attention with online softmax rescaling)
  • 手工优化内存访问模式

没有任何通用的 pattern matching 或 cost model 能从 softmax(QKT)V\text{softmax}(QK^T)V 的算子图中自动推导出这个算法。这是领域特定的算法改写(domain-specific algorithmic rewrite),编译器与领域知识的交汇点。

FlashAttention 深度剖析序列长度 N:5121024204840968192SRAM 大小:164 KB48KB228KBTile 大小: 328 rows暂停标准 Attention分数矩阵 S = QK^T [2048×2048] = 8.0 MB整个 N×N 矩阵在 HBM 中K (N=2048)Q (N=2048)内存访问流程:1读 Q2读 K3写 S 到 HBM4读 S5写 softmax(S)6读 V7写 OFlashAttention分数矩阵 S = QK^T [2048×2048] — 仅 tile 在 SRAM 中SRAM tile: 328x328K (N=2048)Q (N=2048)内存访问流程:1加载 Q tile2加载 K,V tile3SRAM 计算4更新累加器5写 O tile+ online softmax 在 SRAM 内完成HBM 访问量标准 Attention33.0 MBFlashAttention4.5 MBI/O 复杂度标准 Attention: O(Nd + N^2)FlashAttention: O(N^2d^2/M)M = SRAM = 164 KB, d = 64节省86% (7.4x)N=2048, d=64: 标准 S 矩阵 8.0 MB。序列越长,FlashAttention 优势越明显。

融合效果实战对比

理论分析之后,让我们看实际数据。下面的基准对比展示了不同融合策略在典型 Transformer 配置上的表现(数据为教学用估算值)。

四种策略的层次:

  1. No Fusion:每个算子独立 kernel,基线
  2. Element-wise Only:仅融合 pointwise 操作(GELU+dropout、bias+add 等)
  3. Full Inductor:TorchInductor 的完整融合(包括 reduction fusion、epilogue fusion)
  4. Inductor + FlashAttn:在 Inductor 基础上加入 FlashAttention
融合策略基准对比GPT-2 Smallseq=1024 h=768 B=16LLaMA 7Bseq=2048 h=4096 B=1LLaMA 70Bseq=4096 h=8192 B=1吞吐量 (TFLOPS)延迟 (ms)峰值内存 (MB)HBM 访问 (GB)4585130175422214.510.84.8G3.8G3.2G2.4G6.23.82.51.8No FusionElement-wise OnlyFull InductorInductor + FlashAttn关键发现Full Inductor 带来最大单步提升(85→130 TFLOPS)。FlashAttention 额外贡献 35% 吞吐提升并降低 25% 峰值内存。数据为教学用估算值

分析

几个关键观察:

Element-wise 融合的普惠效果:在所有模型规模上,仅做 element-wise 融合就能带来 1.8-2.0x 的吞吐提升。这是”低垂果实”——简单、安全、几乎没有负面影响。

Full Inductor 的价值:在 element-wise 之上,Inductor 的 reduction fusion 和 epilogue fusion 额外贡献 1.5-1.7x 提升。这部分需要 cost model 来避免有害融合。

FlashAttention 随序列长度缩放

  • GPT-2 Small(seq=1024):FlashAttention 额外提升 31%
  • LLaMA 7B(seq=2048):额外提升 35%
  • LLaMA 70B(seq=4096):额外提升 41%

序列越长,N2N^2 的标准 attention 开销越大,FlashAttention 的线性 I/O 优势越明显。

峰值内存的持续下降:融合不仅提升速度,还减少内存占用。从 No Fusion 到 Inductor+FlashAttn,峰值内存下降 37-47%。这对大模型训练尤为关键——省下的内存可以用于更大 batch size 或更长序列。

总结

本文的核心教训:

  1. Cost model 是必需的。融合不是免费的——它可能增加 register pressure、降低 occupancy、增加编译时间。成熟的编译器需要量化分析来做融合决策。

  2. 启发式 vs. 原则性。TorchInductor 用快速启发式(适合 JIT 场景),MLIR 用 tile-and-fuse(适合 AOT 场景)。两者互补,不是互斥。

  3. 算法改写超越通用融合。FlashAttention 展示了领域知识的力量——通过理解 attention 的数学结构,实现了编译器无法自动发现的优化。未来的最佳系统将结合通用融合(编译器)和领域特定改写(库/算法)。

  4. 性能优化的层次:Element-wise fusion(基础)→ Cost-model-guided fusion(进阶)→ Algorithmic rewrite(大师级)。每个层次都有不可替代的价值。

延伸阅读

  • FlashAttention 论文三部曲FA-1FA-2FA-3。建议按顺序阅读,理解从基础 idea 到 Hopper 特化的演进。
  • Roofline Model 原始论文:Williams, Waterman & Patterson (2009), “Roofline: An Insightful Visual Performance Model for Multicore Architectures”。理解 compute-bound vs memory-bound 的经典框架。
  • MLIR Linalg fusion 文档官方 docs。Tile-and-fuse 的实现细节。
  • PyTorch 2 论文:Ansel et al. (2024),ACM DL。TorchInductor 融合策略的工程实现。
  • CUTLASS Epilogue Fusion:NVIDIA 的 CUTLASS 库提供了 GEMM epilogue fusion 的模板化实现,是理解 matmul+activation 融合的最佳参考。