算子融合(下):Cost Model 与融合实战
更新于 2026-04-23
简介
上一篇我们讨论了融合的 WHAT(五种融合类型)和 WHEN(合法性判定算法)。本文回答两个更关键的问题:
- WHETHER — 融合是否有益?不是所有合法的融合都值得做。
- HOW — 实践中如何实现高效融合?从 TorchInductor 的启发式到 FlashAttention 的算法改写。
核心观点:“能融合”不等于”该融合”。盲目融合可能导致 register pressure 增大、occupancy 下降、编译时间爆炸。成熟的编译器需要 cost model 来做出明智的融合决策。
Cost Model 设计
为什么不能”融合一切”?
融合并非越多越好。下图展示了融合程度与性能之间的典型关系——存在一个由 cost model 指导的最优区间:
假设我们有一个合法融合候选 A+B。从上篇我们知道它满足所有合法性条件。但融合后的 kernel 可能比两个独立 kernel 更慢。原因有三:
1. Register Pressure(寄存器压力)
每个 GPU 线程拥有有限的 register。NVIDIA GPU 的 SM(Streaming Multiprocessor)通常有 65536 个 32-bit register。如果一个 kernel 的每个线程需要 32 个 register,blockSize 为 256:
融合后如果 register 需求从 32 增加到 48:
blocks 从 8 降到 5,意味着 SM 上能同时运行的 warp 减少了 37.5%。
2. Occupancy(占用率)
Occupancy 是 GPU 性能的关键指标,定义为:
影响 occupancy 的三大约束:
- Register 文件大小:如上所述,register 用量越多,能启动的 warp 越少
- Shared memory 容量:V100 = 96 KB, A100 = 164 KB, H100 = 228 KB。如果一个 block 需要 32 KB shared memory,A100 上最多 5 个 block(164/32≈5)
- Max threads per block:硬件限制(通常 1024)
低 occupancy 意味着 GPU 难以通过 warp 切换来隐藏内存延迟(latency hiding),导致实际吞吐量远低于峰值。
3. 编译时间
Triton 和 CUDA kernel 的编译时间大致与 kernel 大小的平方成正比。融合后的大 kernel 编译可能从毫秒级跳到秒级,对 JIT 编译(如 torch.compile)场景影响显著。
Roofline Model 的形式化
Cost model 的核心是 Roofline Model(Williams et al., 2009)。对于一个 kernel:
其中:
是一个单调递增函数,反映 occupancy 对实际吞吐量的影响。一般近似为 ,即占用率低于 25% 时吞吐不再线性下降(因为指令级并行仍能提供一些利用率)。
下面的交互组件让你亲手调整硬件参数,观察不同融合决策的效果。
TorchInductor 的 Cost Model 实战
PyTorch 2 的 TorchInductor(Ansel et al., 2024)采用启发式 cost model——不是精确建模,而是基于经验规则做决策。
核心融合启发式
Pointwise 融合(几乎总是做):
Inductor 对连续的 pointwise 操作(element-wise、broadcast)默认执行融合。理由简单:pointwise 操作不需要 shared memory,register 增量通常很小,而消除中间 tensor 的 HBM 读写几乎总是净正收益。
Reduction + Pointwise 融合(有条件):
对于 reduction(如 sum、max)后面紧跟 pointwise 操作,Inductor 会检查:
- Reduction 维度是否足够小(不会导致 shared memory 溢出)
- 融合后的 register 用量是否可控
如果 reduction 跨越了很大的维度(如 [batch, seq_len, hidden_dim] 上的 hidden_dim reduction),融合可能导致 shared memory 需求过高。
MatMul + Epilogue 融合(高价值):
GEMM 操作的 epilogue fusion 是最有价值的融合之一。在 GEMM tile 写回 HBM 之前,直接在寄存器中做 bias add、activation(ReLU/GELU)、dropout 等操作。cuBLAS 和 CUTLASS 都原生支持 epilogue fusion,Triton 则通过代码生成实现。
融合大小控制
# 控制单个融合 group 中最多包含多少个 node
torch._inductor.config.max_fusion_size = 64 # 默认值
# 控制 pointwise 融合的最大 node 数
torch._inductor.config.max_pointwise_cat_size = 8
调试融合决策
实际开发中,理解编译器的融合决策至关重要:
import torch
# 方法 1: 设置 trace 环境变量
# TORCHINDUCTOR_TRACE=1 python my_script.py
# 方法 2: 在代码中启用
torch._inductor.config.trace.enabled = True
torch._inductor.config.trace.graph_diagram = True # 生成融合前后的图
# 方法 3: 查看生成的 Triton kernel
torch._inductor.config.debug = True
# 生成的 kernel 代码在 /tmp/torchinductor_<user>/ 目录下
查看 trace 输出可以发现类似这样的融合日志:
[FUSION] fused pointwise: relu + mul + add → fused_kernel_0 (3 nodes)
[FUSION] skipped: layernorm + large_epilogue (register pressure: 52 > threshold 48)
MLIR 级别的 Fusion
MLIR(Multi-Level Intermediate Representation)提供了比 Inductor 更原则性的融合方法。
Linalg Dialect 的 Fusion
MLIR 的 Linalg dialect 将张量运算表示为结构化操作(structured ops),天然支持 producer-consumer 融合分析。核心操作 linalg.fuse_into_containing_op 将 producer 的计算内联到 consumer 的循环体内。
Tile-and-Fuse:核心策略
MLIR 的融合策略遵循一个关键原则:先 tile,再 fuse。
- Tile consumer:将 consumer 的计算拆分为适合目标内存层级的 tile(如 L1 cache 或 shared memory)
- Fuse producer into tile:将 producer 的计算融合到 consumer 的 tile 循环中
- 保证 working set 适配:由于 tile 大小由内存容量决定,融合后的 working set 天然不会溢出
这与 Inductor 的”先融合、再祈祷不溢出”形成鲜明对比。Tile-and-fuse 是正确性优先的方法——它从内存约束出发设计融合,而非事后检查。
Affine Fusion
对于完美仿射循环嵌套(perfect affine loop nests),MLIR 的 affine dialect 支持基于 polyhedral analysis 的循环融合。这种方法可以自动发现最优的循环融合顺序和 tile 大小,但仅适用于静态形状、仿射索引的场景。
对比总结
| 维度 | TorchInductor | MLIR |
|---|---|---|
| 方法 | 启发式规则 | Tile-and-fuse + polyhedral |
| 优点 | 快速、实用、覆盖常见 pattern | 原则性强、正确性保证 |
| 缺点 | 可能错过优化或做出次优决策 | 编译开销较大、动态形状支持有限 |
| 适用 | JIT 编译(torch.compile) | AOT 编译(部署优化) |
FlashAttention 深度剖析
FlashAttention(Dao et al., 2022)是 ML 系统领域最具影响力的优化之一。它不是通用融合——而是一个领域特定的算法改写。
标准 Attention 的内存瓶颈
标准 Self-Attention 的计算流程:
其中 ( 为序列长度, 为 head dimension)。
关键瓶颈在 ——一个 的矩阵。以 , , FP16 为例:
整个标准 attention 的 HBM 访问量:
- 读 : MB
- 读 : MB
- 写 : MB → HBM
- 读 : MB ← HBM(做 softmax)
- 写 softmax(): MB → HBM
- 读 softmax(): MB ← HBM(乘以 )
- 读 : MB
- 写 : MB
总 HBM 访问: MB,其中 MB 都花在了 的分数矩阵上。
I/O 复杂度:。当 时(通常如此), 项主导。
FlashAttention 的核心思路
FlashAttention 的关键洞察:不需要将整个 矩阵具象化(materialize)。下图展示了其分块策略——将注意力矩阵分成小 tile 在 SRAM 中计算,避免将完整 矩阵写入 HBM:
将 , , 按行分块(tile):
其中 是 SRAM(shared memory)大小(字节), 是 head dimension, 因为 FP16。
算法框架:
对每个 Q tile (Br 行):
初始化 output 累加器 O_tile = 0, 运行最大值 m = -∞, 运行求和 l = 0
对每个 K,V tile (Bc 行):
从 HBM 加载 Q_tile, K_tile, V_tile 到 SRAM
在 SRAM 中计算 S_tile = Q_tile × K_tile^T (大小 Br × Bc,完全在 SRAM 内!)
计算本地 softmax: m_new, l_new, P_tile
用 online softmax rescaling 更新 O_tile
将 O_tile 写回 HBM
每次迭代中,SRAM 中只存在一个 的小矩阵,而非完整的 。
Online Softmax:关键技巧
Softmax 本质上是全局操作——需要知道整行的最大值才能计算。FlashAttention 使用 online softmax(Milakov & Gimelshein, 2018)解决此问题:
对于向量 的 softmax ,可以分块计算:
- 处理第 1 块:,
- 处理第 2 块:,
- 以此类推,每步都用新的 max 值 rescale 之前的累加结果
这样无需存储完整的 维向量,就能正确计算 softmax。
I/O 复杂度分析
FlashAttention 的 HBM 访问量:
外层循环 次,每次:
- 加载 tile: 字节
- 内层循环 次,每次加载 tile + tile: 字节
- 写回 tile: 字节
总 HBM 访问:
由于 :
当 时(SRAM 足够大),简化为 ——与序列长度 成线性关系!
对比标准 attention 的 ,在长序列上 FlashAttention 的 I/O 效率是量级级别的提升。
FlashAttention-2:更少的非矩阵乘法 FLOPs
FlashAttention-2(Dao, 2023)的关键改进:
-
将 rescaling 移出内层循环:减少非 matmul FLOPs。在 GPU 上 Tensor Core 只加速 matmul;其他操作用 CUDA core,速度慢得多。FlashAttention-1 中每步都做 rescaling,产生大量 non-matmul FLOPs。FA-2 延迟 rescaling 到外层循环末尾。
-
更好的 warp 分区:FA-1 在 维度上分 warp(每个 warp 处理 head 的一部分),需要 cross-warp reduction。FA-2 改为在 维度上分 warp(每个 warp 处理不同的 K/V block),无需 cross-warp 通信。
结果:FA-2 达到理论 FLOPs 利用率的 50-73%(A100 上),相比 FA-1 的 25-40% 大幅提升。
FlashAttention-3:Hopper 架构特化
FlashAttention-3(Shah et al., 2024)为 NVIDIA Hopper(H100)架构深度优化:
- Warp-specialized pipeline:利用 Hopper 的异步执行特性,producer warps 负责 HBM→SRAM 数据搬运,consumer warps 负责计算,两者流水线并行。
- FP8 支持:Hopper 的 FP8 Tensor Core 提供双倍 FP16 的吞吐。FA-3 支持 FP8 计算并配合 incoherent processing(非相干处理)保证数值精度。
- Block quantization:将 softmax 中间结果量化为 FP8,减少 register 和 shared memory 占用。
为什么 FlashAttention 不是”算子融合”
值得强调:FlashAttention 不是通用编译器能自动发现的融合。它需要:
- 理解 softmax 的数学性质(可以 online 计算)
- 设计新算法(tiled attention with online softmax rescaling)
- 手工优化内存访问模式
没有任何通用的 pattern matching 或 cost model 能从 的算子图中自动推导出这个算法。这是领域特定的算法改写(domain-specific algorithmic rewrite),编译器与领域知识的交汇点。
融合效果实战对比
理论分析之后,让我们看实际数据。下面的基准对比展示了不同融合策略在典型 Transformer 配置上的表现(数据为教学用估算值)。
四种策略的层次:
- No Fusion:每个算子独立 kernel,基线
- Element-wise Only:仅融合 pointwise 操作(GELU+dropout、bias+add 等)
- Full Inductor:TorchInductor 的完整融合(包括 reduction fusion、epilogue fusion)
- Inductor + FlashAttn:在 Inductor 基础上加入 FlashAttention
分析
几个关键观察:
Element-wise 融合的普惠效果:在所有模型规模上,仅做 element-wise 融合就能带来 1.8-2.0x 的吞吐提升。这是”低垂果实”——简单、安全、几乎没有负面影响。
Full Inductor 的价值:在 element-wise 之上,Inductor 的 reduction fusion 和 epilogue fusion 额外贡献 1.5-1.7x 提升。这部分需要 cost model 来避免有害融合。
FlashAttention 随序列长度缩放:
- GPT-2 Small(seq=1024):FlashAttention 额外提升 31%
- LLaMA 7B(seq=2048):额外提升 35%
- LLaMA 70B(seq=4096):额外提升 41%
序列越长, 的标准 attention 开销越大,FlashAttention 的线性 I/O 优势越明显。
峰值内存的持续下降:融合不仅提升速度,还减少内存占用。从 No Fusion 到 Inductor+FlashAttn,峰值内存下降 37-47%。这对大模型训练尤为关键——省下的内存可以用于更大 batch size 或更长序列。
总结
本文的核心教训:
-
Cost model 是必需的。融合不是免费的——它可能增加 register pressure、降低 occupancy、增加编译时间。成熟的编译器需要量化分析来做融合决策。
-
启发式 vs. 原则性。TorchInductor 用快速启发式(适合 JIT 场景),MLIR 用 tile-and-fuse(适合 AOT 场景)。两者互补,不是互斥。
-
算法改写超越通用融合。FlashAttention 展示了领域知识的力量——通过理解 attention 的数学结构,实现了编译器无法自动发现的优化。未来的最佳系统将结合通用融合(编译器)和领域特定改写(库/算法)。
-
性能优化的层次:Element-wise fusion(基础)→ Cost-model-guided fusion(进阶)→ Algorithmic rewrite(大师级)。每个层次都有不可替代的价值。
延伸阅读
- FlashAttention 论文三部曲:FA-1、FA-2、FA-3。建议按顺序阅读,理解从基础 idea 到 Hopper 特化的演进。
- Roofline Model 原始论文:Williams, Waterman & Patterson (2009), “Roofline: An Insightful Visual Performance Model for Multicore Architectures”。理解 compute-bound vs memory-bound 的经典框架。
- MLIR Linalg fusion 文档:官方 docs。Tile-and-fuse 的实现细节。
- PyTorch 2 论文:Ansel et al. (2024),ACM DL。TorchInductor 融合策略的工程实现。
- CUTLASS Epilogue Fusion:NVIDIA 的 CUTLASS 库提供了 GEMM epilogue fusion 的模板化实现,是理解 matmul+activation 融合的最佳参考。