Operator Fusion (Part I): Taxonomy & Decision Algorithms
Updated 2026-04-13
Introduction: Why Fusion is the Most Important Optimization
Among all ML compiler optimization techniques, Operator Fusion is widely recognized as the most impactful. The PyTorch 2 paper (Ansel et al., 2024) explicitly states that over 80% of TorchInductor’s performance gains come from fusion optimizations.
Why is fusion so critical? The answer lies in the severe mismatch between compute throughput and memory bandwidth in modern GPUs. Consider the NVIDIA A100:
- Peak compute: 312 TFLOPS (FP16 with Tensor Core)
- Memory bandwidth: 2 TB/s (HBM2e)
- Compute-to-bandwidth ratio: 156:1
This means for most operators (especially element-wise operations), the bottleneck is not computation but data movement. If an operator needs to read 12 MB from HBM (High Bandwidth Memory) and write back 12 MB, even if the computation takes only 1 microsecond, memory I/O alone requires about 12 microseconds (24 MB / 2 TB/s).
Roofline Model: Visualizing the Bottleneck
The Roofline Model (Williams et al., 2009) provides an intuitive framework. It defines Arithmetic Intensity (AI) as the ratio of computation to data movement:

AI = FLOPs performed / bytes moved between HBM and the chip (FLOP/byte)
For an element-wise ReLU operator (y = max(0, x)):
- Each element: 1 comparison + 1 select ≈ 0 FLOPs of useful arithmetic
- Each element: read 4 bytes (FP32) + write 4 bytes = 8 bytes
Thus AI ≈ 0. Such operators are entirely memory-bound, leaving GPU compute units mostly idle.
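As a quick sanity check on these numbers, the sketch below (a back-of-the-envelope roofline estimate using the A100 figures quoted above; constants and sizes are illustrative) shows why the ReLU example is governed entirely by memory time:

PEAK_FLOPS = 312e12     # FP16 Tensor Core peak, FLOP/s
PEAK_BW = 2e12          # HBM bandwidth, bytes/s

def kernel_time_s(flops, bytes_moved):
    # Roofline lower bound: a kernel can finish no faster than either its
    # compute time or its memory-transfer time, whichever is larger.
    return max(flops / PEAK_FLOPS, bytes_moved / PEAK_BW)

# Element-wise ReLU over a 12 MB FP32 tensor: ~1 FLOP per element, 24 MB of HBM traffic.
n = 12e6 / 4                                      # number of elements
print(kernel_time_s(flops=n, bytes_moved=24e6))   # ≈ 1.2e-05 s, i.e. the ~12 µs quoted above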
Essence of fusion: Merge multiple operators into a single kernel, eliminating intermediate HBM round-trips, dramatically reducing memory traffic, increasing AI, and better utilizing hardware compute.
Fusion Taxonomy: Five Major Patterns
Based on fusion pattern and implementation complexity, operator fusion can be divided into Simple Fusion and Complex Fusion. Five typical fusion types are described below.
1. Element-wise Fusion
Pattern: Fuse multiple element-wise operations (relu, add, mul, tanh, etc.) into a single kernel.
Example: y = (relu(x) * alpha + beta)
Before optimization:
- `t1 = relu(x)` — 12 MB read + 12 MB write
- `t2 = t1 * alpha` — 12 MB read + 12 MB write
- `y = t2 + beta` — 12 MB read + 12 MB write
- Total: 72 MB HBM I/O, 3 kernel launches
After optimization:
- Single kernel computes `y = relu(x) * alpha + beta` directly
- Total: 24 MB HBM I/O (12 MB read + 12 MB write)
- Savings: 67% memory traffic
Implementation key points:
- Each thread handles one element, performing all computations in registers without writing intermediate results back to HBM.
- Heavily used by both TorchInductor and XLA.
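In practice this fusion is not written by hand. A minimal sketch of how one would ask TorchInductor to fuse the same chain (function name and tensor size are illustrative):

import torch

def f(x, alpha, beta):
    # Three element-wise ops; eager mode launches three kernels and
    # round-trips two intermediate tensors through HBM.
    return torch.relu(x) * alpha + beta

compiled_f = torch.compile(f)               # TorchInductor fuses the chain into one kernel
x = torch.randn(3 * 2**20, device="cuda")   # a 12 MB FP32 tensor, as in the example above
y = compiled_f(x, alpha=1.5, beta=0.5)

Inspecting the generated code (for example via `TORCH_LOGS=output_code`) should show a single Triton kernel covering the whole chain.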
2. Reduction Fusion
Pattern: Fuse reduction operators (sum, max, mean) with surrounding element-wise operations.
Example: L2 norm computation norm = sqrt(sum(x**2))
Before optimization:
- `t1 = x**2` — 12 MB read + 12 MB write
- `t2 = sum(t1)` — 12 MB read + 4 bytes write
- `norm = sqrt(t2)` — 4 bytes read + 4 bytes write
After optimization:
- Single-pass computation of L2 norm, writing only scalar result
- Savings: 67% memory traffic
Implementation key points:
- Leverage warp-level or block-level reduction primitives (e.g., CUDA's `__shfl_down_sync`, `__syncthreads`)
- Avoid materializing intermediate tensors
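A minimal Triton sketch of the fused sum-of-squares (illustrative, not compiler-generated code): each program reduces its tile of x·x on-chip and contributes one scalar via an atomic add, so the squared tensor never reaches HBM; the host applies the final `sqrt` to the accumulated scalar.

import triton
import triton.language as tl

@triton.jit
def fused_sumsq_kernel(x_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    partial = tl.sum(x * x, axis=0)   # square + block-level reduce, all on-chip
    tl.atomic_add(out_ptr, partial)   # only one scalar per block touches HBM (out must be zeroed first)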
3. Broadcast Fusion
Pattern: Fuse reduction + broadcast + element-wise operations; the core pattern is reduce-then-apply.
Example: LayerNorm centering step y = x - mean(x)
Before optimization:
- `mean_val = reduce_mean(x)` — 12 MB read + 4 bytes write
- `mean_broadcast = broadcast(mean_val)` — creates a 12 MB temporary tensor
- `y = x - mean_broadcast` — 24 MB read + 12 MB write
After optimization:
- Each block first computes the mean (shared via shared memory), then each thread directly computes `x[i] - mean`
- Savings: 60% memory traffic
Implementation key points:
- On-the-fly broadcasting: mean stored in register or shared memory, reused by each thread.
- Core optimization for LayerNorm, BatchNorm, RMSNorm, and other normalization operators.
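A sketch of the fused reduce-then-apply step for the centering example (one program per row; assumes a row fits in a single block):

import triton
import triton.language as tl

@triton.jit
def fused_center_kernel(x_ptr, y_ptr, N_COLS, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N_COLS
    x = tl.load(x_ptr + row * N_COLS + cols, mask=mask, other=0.0)
    mean = tl.sum(x, axis=0) / N_COLS                           # reduce: mean stays on-chip
    tl.store(y_ptr + row * N_COLS + cols, x - mean, mask=mask)  # broadcast happens in registers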
4. Transpose/Reshape Elimination
Pattern: Eliminate explicit transpose or reshape operators via stride manipulation.
Example: y = matmul(reshape(transpose(x)), W)
Before optimization:
- `t1 = transpose(x)` — must materialize the transposed result (12 MB)
- `t2 = reshape(t1)` — may require a copy (12 MB)
- `y = matmul(t2, W)` — read 24 MB + write 12 MB
After optimization:
- The matmul kernel directly reads `x` using the transpose + reshape strides
- Savings: 71% memory traffic
Implementation key points:
- Modern CUDA kernels (e.g., CUTLASS) support arbitrary stride input reads.
- Compiler propagates layout information to consumers via stride propagation.
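PyTorch's own stride machinery illustrates the idea; in the snippet below (shapes are illustrative) the transpose is a pure metadata change, and a stride-aware consumer can read the original buffer directly:

import torch

x = torch.randn(1024, 3072, device="cuda")
xt = x.t()                               # a view: swaps strides, moves no data
assert x.data_ptr() == xt.data_ptr()     # same underlying buffer
print(x.stride(), xt.stride())           # (3072, 1) vs. (1, 3072)
# A stride-aware matmul (e.g., a CUTLASS kernel) can consume xt in this form,
# so the compiler never has to materialize the transposed tensor.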
5. FlashAttention: Algorithmic Rewrite
The first four are pattern fusion, while FlashAttention (Dao et al., 2022) is algorithmic fusion, requiring computation reordering.
Standard Attention computation:

Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V

Problem: the intermediate score matrix S = QKᵀ (shape N × N per head) must be materialized, consuming significant memory.
- For example, at sequence length N = 4096 in FP16, the N × N matrix requires ~32 MB per attention head
- Often the GPU memory bottleneck in Transformers
FlashAttention strategy:
- Tiling: Load Q, K, V blocks into SRAM (shared memory / L2 cache)
- Online Softmax: Use incremental softmax trick to complete softmax within tiles, avoiding full materialization
- I/O Complexity: HBM accesses drop from O(Nd + N²) for standard attention to O(N²d²/M), where M is the SRAM size
Performance gains: For long sequences, FlashAttention is 2-4× faster than the standard implementation, with a 10-20× lower memory footprint.
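The enabler for tiling is the online softmax recurrence; a NumPy sketch of the running-max/running-sum update (score tiles arrive one at a time, as they would from SRAM-resident Q·Kᵀ blocks):

import numpy as np

def online_softmax_stats(score_blocks):
    """Process score tiles one at a time, keeping only a running max m and a
    running normalizer l; softmax(x_i) = exp(x_i - m) / l once all tiles are seen."""
    m, l = -np.inf, 0.0
    for block in score_blocks:                 # each block: a 1-D chunk of attention scores
        m_new = max(m, float(block.max()))
        l = l * np.exp(m - m_new) + np.exp(block - m_new).sum()  # rescale old sum to the new max
        m = m_new
    return m, l

# Sanity check against the one-shot softmax normalizer
x = np.random.randn(4096).astype(np.float32)
m, l = online_softmax_stats(np.split(x, 8))
assert np.allclose(l, np.exp(x - m).sum(), rtol=1e-4)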
Fusion Legality Analysis: Five Rules
Not all operator pairs can be fused. Compilers must check these five conditions:
1. Producer-Consumer Relationship
Rule: Only adjacent operator pairs can be fused, where one operator’s output is another’s input.
Counter-example: two independent matmuls `mm1` and `mm2` have no producer-consumer edge between them and cannot be fused under this rule.
2. No Cycle
Rule: Fusion must not introduce a cycle, which would violate the DAG (Directed Acyclic Graph) structure of the computation graph.
Detection method: DFS or topological sort.
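A sketch of such a check (the `graph.successors` accessor is an assumed API): fusing `producer` and `consumer` is rejected if some other path from producer to consumer exists, because the merged node would then depend on its own output.

def fusion_creates_cycle(graph, producer, consumer):
    # DFS from producer's other successors; if consumer is still reachable
    # without using the direct producer->consumer edge, merging the two
    # nodes would create a cycle.
    stack = [s for s in graph.successors(producer) if s is not consumer]
    visited = set()
    while stack:
        node = stack.pop()
        if node is consumer:
            return True
        if node in visited:
            continue
        visited.add(node)
        stack.extend(graph.successors(node))
    return False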
3. Shape Compatibility
Rule: Fused operators must share the same iteration domain.
Compatible scenarios:
- Element-wise ops: identical shapes
- Broadcast ops: alignable via broadcasting
- Reduction ops: computable via tiling
Incompatible scenarios:
- `matmul` output shape `[M, N]` vs. `layer_norm` input shape `[B, S, D]` cannot directly align (requires a reshape)
4. No Side Effects
Rule: Fused operators must have no observable side effects.
Common side-effect operators:
- `dropout`: depends on random number generator state
- In-place operations: modify input tensors
- I/O operations: print, log, checkpoint
Handling strategy:
- In inference mode, `dropout` degenerates to identity and can fuse
- The compiler must mark side effects to prevent fusion across them
5. Memory Fits in SRAM
Rule: Fused kernel’s total intermediate results must fit in SRAM (CUDA’s shared memory + register file).
A100 SRAM limits:
- Shared memory per SM: 164 KB (dynamically configurable)
- Register file per SM: 256 KB
- Practically usable: ~48 KB per block (accounting for occupancy and the default static shared-memory limit)
Examples:
- `mm1` (12 KB) + `gelu` (2 KB) = 14 KB ✅
- `mm1` (12 KB) + `mm2` (12 KB) = 24 KB ⚠️ (requires tiling)
- `mm1` (32 KB) + `softmax` (32 KB) = 64 KB ❌ (exceeds the limit)
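A toy version of this legality check, using the conservative ~48 KB budget from above (`tile_bytes` is a hypothetical per-op estimate of its on-chip intermediate size):

SRAM_BUDGET_BYTES = 48 * 1024   # conservative per-block budget from the A100 numbers above

def fits_in_sram(ops, budget=SRAM_BUDGET_BYTES):
    # Sum of the live intermediate tiles each op keeps on-chip; if the total
    # exceeds the budget, the fusion group must be split or tiled further.
    return sum(op.tile_bytes for op in ops) <= budget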
Fusion Decision Algorithms: Greedy vs. Graph Coloring
Compilers must automatically decide which operators to fuse. Two mainstream approaches:
1. Greedy Fusion — TorchInductor
Strategy: Traverse the computation graph in topological order, immediately fusing fusible edges.
Pseudocode:
def greedy_fusion(graph):
    # Each node starts in its own fusion group.
    groups = {node: {node} for node in graph.nodes}
    # Visit edges with producers in topological order, fusing eagerly.
    for producer, consumer in topological_order(graph.edges):
        if can_fuse(producer, consumer):
            merged = groups[producer] | groups[consumer]
            # Point every member of both groups at the merged set.
            for node in merged:
                groups[node] = merged
    return groups
Advantages:
- Simple and fast, roughly O(V + E) in the size of the graph
- Easy to implement and debug
Disadvantages:
- Locally optimal, may miss globally better fusion plans
- Sensitive to graph traversal order
2. Graph Coloring — XLA
Strategy: Model fusion as graph coloring problem:
- Each operator is a node
- Two operators that cannot fuse are connected by an edge
- Goal: color all nodes with minimum colors (same color = same fusion group)
Optimization objectives:
- Minimize memory traffic (estimated via heuristic functions)
- Minimize kernel launch count (number of colors)
Advantages:
- Can find globally better solutions (the underlying problem is NP-hard, so approximate/heuristic solvers are used)
- Supports complex constraints (e.g., memory budget, occupancy)
Disadvantages:
- Longer compile time (typically superlinear in graph size)
- Complex implementation, difficult to debug
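For illustration only, a greedy coloring over the conflict graph might look like the sketch below; XLA's real pass is cost-driven and considerably more involved:

def color_fusion_graph(nodes, conflicts):
    # conflicts[n] = set of nodes that must NOT share a fusion group with n
    # (e.g., pairs that fail the legality rules above). Nodes assigned the
    # same color end up in the same fusion group, i.e. the same kernel.
    colors = {}
    for node in nodes:                          # e.g., visited in topological order
        taken = {colors[m] for m in conflicts.get(node, ()) if m in colors}
        colors[node] = next(c for c in range(len(nodes) + 1) if c not in taken)
    return colors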
TorchInductor vs. XLA
| Dimension | TorchInductor | XLA |
|---|---|---|
| Algorithm | Greedy | Graph Coloring |
| Compile Speed | Fast (< 100ms) | Slow (> 1s) |
| Fusion Quality | Good (covers ~90% of common patterns) | Excellent (the extra gains appear mainly on rarer patterns) |
| Use Case | Eager mode, JIT | AOT, static graph |
PyTorch 2 prioritizes compile speed because in Eager execution mode, compilation overhead directly impacts user experience. XLA targets AOT compilation and can afford longer compile times.
TorchInductor Fusion Implementation Details
Lowering Fusion Groups
Fused groups generate single kernels via Triton IR:
# Fusion group: relu → mul → add
import triton
import triton.language as tl

@triton.jit
def fused_kernel(x_ptr, out_ptr, alpha, beta, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    # Load
    x = tl.load(x_ptr + offsets, mask=mask)
    # Fused computation (in registers)
    t1 = tl.maximum(x, 0.0)   # relu
    t2 = t1 * alpha           # mul
    out = t2 + beta           # add
    # Store
    tl.store(out_ptr + offsets, out, mask=mask)
Key points:
- All intermediate results (`t1`, `t2`) stay in registers; nothing is written back to HBM
- Triton's `tl.load` / `tl.store` automatically handle memory alignment and coalescing
- `BLOCK_SIZE` is a compile-time constant, so Triton can unroll loops automatically
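A host-side launch sketch for the kernel above (grid and block size are illustrative):

import torch
import triton

x = torch.randn(3 * 2**20, device="cuda")
out = torch.empty_like(x)
N = x.numel()
BLOCK_SIZE = 1024
grid = (triton.cdiv(N, BLOCK_SIZE),)      # one program instance per BLOCK_SIZE elements
fused_kernel[grid](x, out, 1.5, 0.5, N, BLOCK_SIZE=BLOCK_SIZE)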
Scheduler: Determining Kernel Launch Order
Fused graphs may contain multiple independent fusion groups. Scheduler handles:
- Topological Sort: Ensure correct dependencies
- Memory Planning: Decide when to allocate/free intermediate buffers
- Concurrent Execution: Leverage CUDA streams for parallelism
TorchInductor uses a dynamic scheduler that adjusts at runtime based on the actual tensor shapes.
Summary: Fusion is the Silver Bullet for Memory-Bound Problems
This article systematically introduced operator fusion’s taxonomy, legality analysis, and decision algorithms. Key takeaways:
- Why fusion matters most: Modern GPUs have compute-to-bandwidth ratios of 156:1; most operators are memory-bound
- Five fusion patterns: Element-wise, Reduction, Broadcast, Transpose, FlashAttention
- Five legality rules: Producer-consumer, No cycle, Shape compatible, No side effect, Memory fits
- Two decision algorithms: Greedy (TorchInductor, fast), Graph Coloring (XLA, higher quality but slower)
The next article will dive into Cost Models, exploring how compilers quantify fusion benefits and make optimal decisions in complex scenarios.
Further Reading
- FlashAttention principles: Dao et al. (2022) paper details tiling and online softmax algorithms
- XLA Fusion implementation: TensorFlow XLA's HLO Fusion Pass source code (`xla/service/gpu/gpu_fusible.cc`)
- Triton language: OpenAI's Triton provides easier kernel authoring than CUDA
- Roofline Model tools: NERSC’s Roofline Toolkit automatically analyzes operator AI and bottlenecks