本站内容由 AI 生成,可能存在错误。如发现问题,欢迎到 GitHub Issues 反馈。

算子融合(上):融合类型学与判定算法

算子融合(上):融合类型学与判定算法

更新于 2026-04-23

查看全景图用户代码全景图计算图捕获IR 设计优化 Pass算子融合8. 融合类型学你在这里代码生成调度与执行硬件执行

简介:为什么融合是最重要的优化

在 ML 编译器的众多优化技术中,Operator Fusion(算子融合) 被普遍认为是影响最大、收益最明显的优化。PyTorch 2 论文(Ansel et al., 2024)明确指出,TorchInductor 的性能提升中,超过 80% 来自融合优化。

为什么融合如此重要?原因在于现代 GPU 的计算能力内存带宽严重不匹配。以 NVIDIA A100 为例:

  • 计算峰值:312 TFLOPS (FP16 with Tensor Core)
  • 内存带宽:2 TB/s (HBM2e)
  • 算力/带宽比:156:1

这意味着,对于大多数算子(尤其是 element-wise 操作),瓶颈不在计算,而在数据搬运。如果一个算子需要从 HBM(High Bandwidth Memory)读取 12 MB 数据、写回 12 MB 结果,即使计算本身只需 1 微秒,光是内存 I/O 就需要约 12 微秒(24 MB / 2 TB/s)。

Roofline Model:直观理解瓶颈

Roofline Model(Williams et al., 2009)是理解这一现象的经典工具。下图展示了 A100 GPU 上的 Roofline Model,常见算子按算术强度分布在不同区域:

Roofline Model (A100 GPU)
Roofline Model (A100 GPU)0.111010010000.1110100算术强度 Arithmetic Intensity (FLOPs/Byte)性能 (TFLOPS)2 TB/s HBM312 TFLOPS (FP16)脊点 (Ridge Point)内存受限区计算受限区ReLU (逐元素)LayerNormSoftmaxGEMV (小 batch)GEMM (大 batch)

Roofline Model 将算子的算术强度(Arithmetic Intensity, AI) 定义为:

AI=FLOPsMemory I/O (Bytes)\text{AI} = \frac{\text{FLOPs}}{\text{Memory I/O (Bytes)}}

对于一个 element-wise ReLU 算子(y = max(0, x)):

  • 每个元素 1 次比较 + 1 次条件赋值 ≈ 0 FLOPs(忽略分支预测)
  • 每个元素读 4 bytes (FP32) + 写 4 bytes = 8 bytes

因此 AI ≈ 0。这种算子完全受限于内存带宽(memory-bound),GPU 的算力几乎完全闲置。

融合的本质:将多个算子合并为一个 kernel,减少中间结果的 HBM 往返,从而大幅降低内存流量,提升 AI,使硬件算力得到更充分利用。

融合类型学:五大模式

根据融合模式和实现难度,算子融合可分为简单融合(Simple Fusion)复杂融合(Complex Fusion) 两大类。下图概览了四种典型融合模式的 before→after 变换:

算子融合模式一览
算子融合模式一览逐元素融合内存受限AddReLU融合FusedAddReLU归约融合内存受限MatMulAdd bias融合FusedLinear广播融合内存受限MeanBroadcastSub融合FusedCenter算法级改写计算受限Q×KᵀScaleMaskSoftmax×V融合FlashAttention内存受限计算受限

下面交互演示展示了五种典型融合类型的详细动态。

算子融合类型学简单融合Element-wise 融合Reduction 融合Broadcast 融合Transpose/Reshape 消除复杂融合FlashAttention(算法改写)融合前xReLU 12 12× 12 12+ 12 12output融合后xReLU+×+add 12 12outputElement-wise 融合逐元素算子链,消除中间缓冲区融合前:3 次内核启动,(12+12) MB HBM 读写 × 3 = 72 MB。融合后:1 次启动,12+12 = 24 MB。节省 67% 内存流量。节省: 67% (7224 MB)

1. Element-wise Fusion(逐元素融合)

模式:将多个逐元素操作(reluaddmultanh 等)融合为单个 kernel。

示例y = (relu(x) * alpha + beta)

优化前

  1. t1 = relu(x) — 12 MB 读 + 12 MB 写
  2. t2 = t1 * alpha — 12 MB 读 + 12 MB 写
  3. y = t2 + beta — 12 MB 读 + 12 MB 写
  • 总计:72 MB HBM I/O,3 次 kernel 启动

优化后

  • 单个 kernel 直接计算 y = relu(x) * alpha + beta
  • 总计:24 MB HBM I/O(12 MB 读 + 12 MB 写)
  • 节省:67% 内存流量

实现要点

  • 每个 thread 负责一个元素,直接在寄存器中完成所有计算,无需写回中间结果。
  • TorchInductor 和 XLA 都大量使用此模式。

2. Reduction Fusion(归约融合)

模式:将 reduction 算子(summaxmean)与其前后的 element-wise 操作融合。

示例:L2 范数计算 norm = sqrt(sum(x**2))

优化前

  1. t1 = x**2 — 12 MB 读 + 12 MB 写
  2. t2 = sum(t1) — 12 MB 读 + 4 bytes 写
  3. norm = sqrt(t2) — 4 bytes 读 + 4 bytes 写

优化后

  • 单趟扫描(single-pass)计算 L2 范数,只写出标量结果
  • 节省:67% 内存流量

实现要点

  • 利用 warp-level 或 block-level reduce primitives(如 CUDA 的 __shfl_down_sync__syncthreads)。
  • 避免 materialize 中间张量。

3. Broadcast Fusion(广播融合)

模式:将 reduction + broadcast + element-wise 融合,核心模式为 reduce-then-apply

示例:LayerNorm 的 centering 步骤 y = x - mean(x)

优化前

  1. mean_val = reduce_mean(x) — 12 MB 读 + 4 bytes 写
  2. mean_broadcast = broadcast(mean_val) — 产生 12 MB 临时张量
  3. y = x - mean_broadcast — 24 MB 读 + 12 MB 写

优化后

  • 每个 thread 先计算全局均值(通过 shared memory 共享),再直接计算 x[i] - mean
  • 节省:60% 内存流量

实现要点

  • On-the-fly broadcasting:均值存储在寄存器或 shared memory,每个 thread 重复使用。
  • 这是 LayerNorm、BatchNorm、RMSNorm 等归一化算子的核心优化。

4. Transpose/Reshape Elimination(布局变换消除)

模式:通过 stride manipulation(步幅操作) 消除显式的 transpose 或 reshape 算子。

示例y = matmul(reshape(transpose(x)), W)

优化前

  1. t1 = transpose(x) — 需要 materialize 转置结果(12 MB)
  2. t2 = reshape(t1) — 可能需要拷贝(12 MB)
  3. y = matmul(t2, W) — 读 24 MB + 写 12 MB

优化后

  • matmul kernel 直接按照 transpose + reshape 后的 stride 读取 x
  • 节省:71% 内存流量

实现要点

  • 现代 CUDA kernel(如 CUTLASS)支持任意 stride 的输入读取。
  • 编译器通过 stride propagation 将 layout 信息传递给 consumer。

5. FlashAttention:算法级改写

前四种是模式级融合(pattern fusion),而 FlashAttention(Dao et al., 2022)是算法级融合(algorithmic fusion),需要改变计算顺序。

标准 Attention 计算

Attention(Q,K,V)=softmax ⁣(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V

问题:中间矩阵 S=QKTRN×NS = QK^T \in \mathbb{R}^{N \times N} 占用大量显存。

  • N=4096N = 4096, FP16 为例:SS 占用 32 MB
  • Transformer 的 GPU 显存瓶颈往往在此

FlashAttention 策略

  1. Tiling:将 QQKKVV 分块加载到 SRAM(shared memory / L2 cache)
  2. Online Softmax:利用增量 softmax 技巧,在 tile 内完成 softmax 计算,无需 materialize 完整 SS 矩阵
  3. I/O Complexity:从 O(N2)O(N^2) 降至 O ⁣(N2d2M)O\!\left(\frac{N^2 d^2}{M}\right),其中 MM 是 SRAM 大小

性能提升:在长序列场景(N>2048N > 2048),FlashAttention 相比标准实现快 2-4 倍,且显存占用降低 10-20 倍

融合合法性分析:五大规则

并非任意两个算子都能融合。编译器需要检查以下五项条件:

融合合法性检查器点击两个节点检查能否融合x[128,768]W₁[768,3072]x@W₁[128,3072]GELU[128,3072]W₂[3072,768]GELU@W₂[128,768]dropout[128,768]+residual[128,768]LayerNorm[128,768]output[128,768]点击两个节点检查能否融合

1. Producer-Consumer Relationship(生产者-消费者关系)

规则:只能融合相邻的算子对,即一个算子的输出是另一个算子的输入。

反例mm1mm2 没有直接依赖关系,无法融合。

2. No Cycle(无循环依赖)

规则:融合后不能产生循环依赖(cycle),否则违反 DAG(有向无环图)的拓扑序。

检测方法:DFS 或拓扑排序。

3. Shape Compatibility(形状兼容)

规则:融合的算子必须能共享同一 iteration domain(迭代域)

兼容场景

  • Element-wise ops:形状完全相同
  • Broadcast ops:可通过 broadcasting 对齐
  • Reduction ops:可通过 tiling 分块计算

不兼容场景

  • matmul 的输出形状 [M, N]layer_norm 的输入形状 [B, S, D] 无法直接对齐(需 reshape)

4. No Side Effects(无副作用)

规则:融合的算子不能有观察者可见的副作用(observable side effects)

常见副作用算子

  • dropout:依赖随机数生成器状态
  • in-place 操作:修改输入张量
  • I/O 操作:打印、日志、checkpoint

处理策略

  • 在 inference 模式下,dropout 退化为 identity,可融合
  • 编译器需标记副作用,避免跨边界融合

5. Memory Fits in SRAM(内存可容纳)

规则:融合后的 kernel 的中间结果总和必须能放入 SRAM(CUDA 的 shared memory + register file)。

A100 SRAM 限制

  • Shared memory per SM:164 KB(动态配置)
  • Register file per SM:256 KB
  • 实际可用:~48 KB(考虑 bank conflict、occupancy)

示例

  • mm1 (12 KB) + gelu (2 KB) = 14 KB ✅
  • mm1 (12 KB) + mm2 (12 KB) = 24 KB ⚠️(需 tiling)
  • mm1 (32 KB) + softmax (32 KB) = 64 KB ❌(超出限制)

融合判定算法:Greedy vs. Graph Coloring

编译器需要自动决定哪些算子应该融合。主流方法有两种:

1. Greedy Fusion(贪心融合)— TorchInductor

策略:按拓扑序遍历计算图,遇到可融合的 edge 就立即融合。

算法伪代码

def greedy_fusion(graph):
    groups = {node: {node} for node in graph.nodes}
    for edge in topological_order(graph.edges):
        producer, consumer = edge
        if can_fuse(producer, consumer):
            groups[consumer] = groups[producer] | groups[consumer]
    return groups

优点

  • 简单快速,O(E)O(E) 复杂度
  • 易于实现,调试友好

缺点

  • 局部最优,可能错过全局更优的融合方案
  • 对 graph 的 traversal 顺序敏感
步骤 0
融合判定算法演示:Greedy Fusionx [128×768]LayerNormW₁x@W₁GELUW₂GELU@W₂+residualoutput初始化初始化:每个节点各自为一组

2. Graph Coloring(图着色)— XLA

策略:将融合问题建模为图着色问题

  • 每个算子是一个节点
  • 两个算子不能融合时,连一条边
  • 目标:用最少的颜色给所有节点着色(同色 = 同一融合组)

优化目标

  • 最小化内存流量(通过启发式函数估计)
  • 最小化 kernel 启动次数(颜色数量)

优点

  • 理论上能找到全局更优解(在 NP-hard 框架下近似求解)
  • 支持复杂的约束(如 memory budget、occupancy)

缺点

  • 编译时间较长(O(V2)O(V^2)O(V3)O(V^3)
  • 实现复杂,调试困难

TorchInductor vs. XLA

维度TorchInductorXLA
算法GreedyGraph Coloring
编译速度快(< 100ms)慢(> 1s)
融合质量良好(覆盖 90% 常见模式)优秀(但对罕见模式提升有限)
适用场景Eager mode, JITAOT, static graph

PyTorch 2 的选择是编译速度优先,因为在 Eager 执行模式下,编译开销会直接影响用户体验。XLA 则面向 AOT 编译,可以承受更长的编译时间。

TorchInductor 融合实现细节

TorchInductor 融合流水线FX Graph输入计算图ATen 算子序列Fusion Engine贪心分组合法性检查 + 分组Triton IR生成 KernelBlock-level 模板Compiled Kernel编译输出PTX + launch config示例:融合三算子matmul+ biasrelu融合单 Triton Kernel数据层次:RegistersSRAMHBM融合使中间结果留在寄存器/SRAM,避免 HBM 往返

Fusion Group 的 Lowering

融合后的 group 通过 Triton IR 生成单个 kernel:

# Fusion group: relu → mul → add
@triton.jit
def fused_kernel(x_ptr, out_ptr, alpha, beta, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    
    # Load
    x = tl.load(x_ptr + offsets, mask=mask)
    
    # Fused computation (in registers)
    t1 = tl.maximum(x, 0.0)  # relu
    t2 = t1 * alpha          # mul
    out = t2 + beta          # add
    
    # Store
    tl.store(out_ptr + offsets, out, mask=mask)

关键点

  • 所有中间结果(t1t2)存储在寄存器中,无需写回 HBM
  • Triton 的 tl.load / tl.store 自动处理内存对齐、coalescing
  • BLOCK_SIZE 是编译时常量,Triton 会自动展开循环

Scheduler:决定 Kernel 启动顺序

融合后的 graph 可能包含多个独立的 fusion group。Scheduler 负责:

  1. Topological Sort:确保依赖关系正确
  2. Memory Planning:决定何时分配/释放中间 buffer
  3. Concurrent Execution:利用 CUDA stream 实现并行

TorchInductor 使用 Dynamic Scheduler,在运行时根据实际 tensor shape 动态调整。

总结:融合是 Memory-Bound 问题的银弹

本文系统介绍了算子融合的类型学合法性分析判定算法。核心要点:

  1. 为什么融合最重要:现代 GPU 的算力/带宽比高达 156:1,大多数算子是 memory-bound
  2. 五大融合模式:Element-wise、Reduction、Broadcast、Transpose、FlashAttention
  3. 五大合法性规则:Producer-consumer、No cycle、Shape compatible、No side effect、Memory fits
  4. 两大判定算法:Greedy(TorchInductor,快)、Graph Coloring(XLA,优)

下一篇文章将深入 Cost Model(代价模型),探讨编译器如何量化融合的收益,以及如何在复杂场景下做出最优决策。

延伸阅读

  • FlashAttention 原理:Dao et al. (2022) 的原论文详细推导了 tiling 和 online softmax 的算法
  • XLA Fusion 实现:TensorFlow XLA 的 HLO Fusion Pass 源码(xla/service/gpu/gpu_fusible.cc
  • Triton 语言:OpenAI 的 Triton 提供了比 CUDA 更易用的 kernel 编写方式
  • Roofline Model 工具:NERSC 提供的 Roofline Toolkit 可以自动分析算子的 AI 和瓶颈