算子融合(上):融合类型学与判定算法
更新于 2026-04-23
简介:为什么融合是最重要的优化
在 ML 编译器的众多优化技术中,Operator Fusion(算子融合) 被普遍认为是影响最大、收益最明显的优化。PyTorch 2 论文(Ansel et al., 2024)明确指出,TorchInductor 的性能提升中,超过 80% 来自融合优化。
为什么融合如此重要?原因在于现代 GPU 的计算能力和内存带宽严重不匹配。以 NVIDIA A100 为例:
- 计算峰值:312 TFLOPS (FP16 with Tensor Core)
- 内存带宽:2 TB/s (HBM2e)
- 算力/带宽比:156:1
这意味着,对于大多数算子(尤其是 element-wise 操作),瓶颈不在计算,而在数据搬运。如果一个算子需要从 HBM(High Bandwidth Memory)读取 12 MB 数据、写回 12 MB 结果,即使计算本身只需 1 微秒,光是内存 I/O 就需要约 12 微秒(24 MB / 2 TB/s)。
Roofline Model:直观理解瓶颈
Roofline Model(Williams et al., 2009)是理解这一现象的经典工具。下图展示了 A100 GPU 上的 Roofline Model,常见算子按算术强度分布在不同区域:
Roofline Model 将算子的算术强度(Arithmetic Intensity, AI) 定义为:
对于一个 element-wise ReLU 算子(y = max(0, x)):
- 每个元素 1 次比较 + 1 次条件赋值 ≈ 0 FLOPs(忽略分支预测)
- 每个元素读 4 bytes (FP32) + 写 4 bytes = 8 bytes
因此 AI ≈ 0。这种算子完全受限于内存带宽(memory-bound),GPU 的算力几乎完全闲置。
融合的本质:将多个算子合并为一个 kernel,减少中间结果的 HBM 往返,从而大幅降低内存流量,提升 AI,使硬件算力得到更充分利用。
融合类型学:五大模式
根据融合模式和实现难度,算子融合可分为简单融合(Simple Fusion) 和复杂融合(Complex Fusion) 两大类。下图概览了四种典型融合模式的 before→after 变换:
下面交互演示展示了五种典型融合类型的详细动态。
1. Element-wise Fusion(逐元素融合)
模式:将多个逐元素操作(relu、add、mul、tanh 等)融合为单个 kernel。
示例:y = (relu(x) * alpha + beta)
优化前:
t1 = relu(x)— 12 MB 读 + 12 MB 写t2 = t1 * alpha— 12 MB 读 + 12 MB 写y = t2 + beta— 12 MB 读 + 12 MB 写
- 总计:72 MB HBM I/O,3 次 kernel 启动
优化后:
- 单个 kernel 直接计算
y = relu(x) * alpha + beta - 总计:24 MB HBM I/O(12 MB 读 + 12 MB 写)
- 节省:67% 内存流量
实现要点:
- 每个 thread 负责一个元素,直接在寄存器中完成所有计算,无需写回中间结果。
- TorchInductor 和 XLA 都大量使用此模式。
2. Reduction Fusion(归约融合)
模式:将 reduction 算子(sum、max、mean)与其前后的 element-wise 操作融合。
示例:L2 范数计算 norm = sqrt(sum(x**2))
优化前:
t1 = x**2— 12 MB 读 + 12 MB 写t2 = sum(t1)— 12 MB 读 + 4 bytes 写norm = sqrt(t2)— 4 bytes 读 + 4 bytes 写
优化后:
- 单趟扫描(single-pass)计算 L2 范数,只写出标量结果
- 节省:67% 内存流量
实现要点:
- 利用 warp-level 或 block-level reduce primitives(如 CUDA 的
__shfl_down_sync、__syncthreads)。 - 避免 materialize 中间张量。
3. Broadcast Fusion(广播融合)
模式:将 reduction + broadcast + element-wise 融合,核心模式为 reduce-then-apply。
示例:LayerNorm 的 centering 步骤 y = x - mean(x)
优化前:
mean_val = reduce_mean(x)— 12 MB 读 + 4 bytes 写mean_broadcast = broadcast(mean_val)— 产生 12 MB 临时张量y = x - mean_broadcast— 24 MB 读 + 12 MB 写
优化后:
- 每个 thread 先计算全局均值(通过 shared memory 共享),再直接计算
x[i] - mean - 节省:60% 内存流量
实现要点:
- On-the-fly broadcasting:均值存储在寄存器或 shared memory,每个 thread 重复使用。
- 这是 LayerNorm、BatchNorm、RMSNorm 等归一化算子的核心优化。
4. Transpose/Reshape Elimination(布局变换消除)
模式:通过 stride manipulation(步幅操作) 消除显式的 transpose 或 reshape 算子。
示例:y = matmul(reshape(transpose(x)), W)
优化前:
t1 = transpose(x)— 需要 materialize 转置结果(12 MB)t2 = reshape(t1)— 可能需要拷贝(12 MB)y = matmul(t2, W)— 读 24 MB + 写 12 MB
优化后:
- matmul kernel 直接按照 transpose + reshape 后的 stride 读取
x - 节省:71% 内存流量
实现要点:
- 现代 CUDA kernel(如 CUTLASS)支持任意 stride 的输入读取。
- 编译器通过 stride propagation 将 layout 信息传递给 consumer。
5. FlashAttention:算法级改写
前四种是模式级融合(pattern fusion),而 FlashAttention(Dao et al., 2022)是算法级融合(algorithmic fusion),需要改变计算顺序。
标准 Attention 计算:
问题:中间矩阵 占用大量显存。
- 以 , FP16 为例: 占用 32 MB
- Transformer 的 GPU 显存瓶颈往往在此
FlashAttention 策略:
- Tiling:将 、、 分块加载到 SRAM(shared memory / L2 cache)
- Online Softmax:利用增量 softmax 技巧,在 tile 内完成 softmax 计算,无需 materialize 完整 矩阵
- I/O Complexity:从 降至 ,其中 是 SRAM 大小
性能提升:在长序列场景(),FlashAttention 相比标准实现快 2-4 倍,且显存占用降低 10-20 倍。
融合合法性分析:五大规则
并非任意两个算子都能融合。编译器需要检查以下五项条件:
1. Producer-Consumer Relationship(生产者-消费者关系)
规则:只能融合相邻的算子对,即一个算子的输出是另一个算子的输入。
反例:mm1 和 mm2 没有直接依赖关系,无法融合。
2. No Cycle(无循环依赖)
规则:融合后不能产生循环依赖(cycle),否则违反 DAG(有向无环图)的拓扑序。
检测方法:DFS 或拓扑排序。
3. Shape Compatibility(形状兼容)
规则:融合的算子必须能共享同一 iteration domain(迭代域)。
兼容场景:
- Element-wise ops:形状完全相同
- Broadcast ops:可通过 broadcasting 对齐
- Reduction ops:可通过 tiling 分块计算
不兼容场景:
matmul的输出形状[M, N]与layer_norm的输入形状[B, S, D]无法直接对齐(需 reshape)
4. No Side Effects(无副作用)
规则:融合的算子不能有观察者可见的副作用(observable side effects)。
常见副作用算子:
dropout:依赖随机数生成器状态in-place操作:修改输入张量- I/O 操作:打印、日志、checkpoint
处理策略:
- 在 inference 模式下,
dropout退化为 identity,可融合 - 编译器需标记副作用,避免跨边界融合
5. Memory Fits in SRAM(内存可容纳)
规则:融合后的 kernel 的中间结果总和必须能放入 SRAM(CUDA 的 shared memory + register file)。
A100 SRAM 限制:
- Shared memory per SM:164 KB(动态配置)
- Register file per SM:256 KB
- 实际可用:~48 KB(考虑 bank conflict、occupancy)
示例:
mm1(12 KB) +gelu(2 KB) = 14 KB ✅mm1(12 KB) +mm2(12 KB) = 24 KB ⚠️(需 tiling)mm1(32 KB) +softmax(32 KB) = 64 KB ❌(超出限制)
融合判定算法:Greedy vs. Graph Coloring
编译器需要自动决定哪些算子应该融合。主流方法有两种:
1. Greedy Fusion(贪心融合)— TorchInductor
策略:按拓扑序遍历计算图,遇到可融合的 edge 就立即融合。
算法伪代码:
def greedy_fusion(graph):
groups = {node: {node} for node in graph.nodes}
for edge in topological_order(graph.edges):
producer, consumer = edge
if can_fuse(producer, consumer):
groups[consumer] = groups[producer] | groups[consumer]
return groups
优点:
- 简单快速, 复杂度
- 易于实现,调试友好
缺点:
- 局部最优,可能错过全局更优的融合方案
- 对 graph 的 traversal 顺序敏感
2. Graph Coloring(图着色)— XLA
策略:将融合问题建模为图着色问题:
- 每个算子是一个节点
- 两个算子不能融合时,连一条边
- 目标:用最少的颜色给所有节点着色(同色 = 同一融合组)
优化目标:
- 最小化内存流量(通过启发式函数估计)
- 最小化 kernel 启动次数(颜色数量)
优点:
- 理论上能找到全局更优解(在 NP-hard 框架下近似求解)
- 支持复杂的约束(如 memory budget、occupancy)
缺点:
- 编译时间较长( 到 )
- 实现复杂,调试困难
TorchInductor vs. XLA
| 维度 | TorchInductor | XLA |
|---|---|---|
| 算法 | Greedy | Graph Coloring |
| 编译速度 | 快(< 100ms) | 慢(> 1s) |
| 融合质量 | 良好(覆盖 90% 常见模式) | 优秀(但对罕见模式提升有限) |
| 适用场景 | Eager mode, JIT | AOT, static graph |
PyTorch 2 的选择是编译速度优先,因为在 Eager 执行模式下,编译开销会直接影响用户体验。XLA 则面向 AOT 编译,可以承受更长的编译时间。
TorchInductor 融合实现细节
Fusion Group 的 Lowering
融合后的 group 通过 Triton IR 生成单个 kernel:
# Fusion group: relu → mul → add
@triton.jit
def fused_kernel(x_ptr, out_ptr, alpha, beta, N, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < N
# Load
x = tl.load(x_ptr + offsets, mask=mask)
# Fused computation (in registers)
t1 = tl.maximum(x, 0.0) # relu
t2 = t1 * alpha # mul
out = t2 + beta # add
# Store
tl.store(out_ptr + offsets, out, mask=mask)
关键点:
- 所有中间结果(
t1、t2)存储在寄存器中,无需写回 HBM - Triton 的
tl.load/tl.store自动处理内存对齐、coalescing BLOCK_SIZE是编译时常量,Triton 会自动展开循环
Scheduler:决定 Kernel 启动顺序
融合后的 graph 可能包含多个独立的 fusion group。Scheduler 负责:
- Topological Sort:确保依赖关系正确
- Memory Planning:决定何时分配/释放中间 buffer
- Concurrent Execution:利用 CUDA stream 实现并行
TorchInductor 使用 Dynamic Scheduler,在运行时根据实际 tensor shape 动态调整。
总结:融合是 Memory-Bound 问题的银弹
本文系统介绍了算子融合的类型学、合法性分析和判定算法。核心要点:
- 为什么融合最重要:现代 GPU 的算力/带宽比高达 156:1,大多数算子是 memory-bound
- 五大融合模式:Element-wise、Reduction、Broadcast、Transpose、FlashAttention
- 五大合法性规则:Producer-consumer、No cycle、Shape compatible、No side effect、Memory fits
- 两大判定算法:Greedy(TorchInductor,快)、Graph Coloring(XLA,优)
下一篇文章将深入 Cost Model(代价模型),探讨编译器如何量化融合的收益,以及如何在复杂场景下做出最优决策。
延伸阅读
- FlashAttention 原理:Dao et al. (2022) 的原论文详细推导了 tiling 和 online softmax 的算法
- XLA Fusion 实现:TensorFlow XLA 的 HLO Fusion Pass 源码(
xla/service/gpu/gpu_fusible.cc) - Triton 语言:OpenAI 的 Triton 提供了比 CUDA 更易用的 kernel 编写方式
- Roofline Model 工具:NERSC 提供的 Roofline Toolkit 可以自动分析算子的 AI 和瓶颈