
Operator Fusion (Part I): Taxonomy & Decision Algorithms


Updated 2026-04-13


Introduction: Why Fusion is the Most Important Optimization

Among all ML compiler optimization techniques, Operator Fusion is widely recognized as the most impactful. The PyTorch 2 paper (Ansel et al., 2024) explicitly states that over 80% of TorchInductor’s performance gains come from fusion optimizations.

Why is fusion so critical? The answer lies in the severe mismatch between compute throughput and memory bandwidth in modern GPUs. Consider the NVIDIA A100:

  • Peak compute: 312 TFLOPS (FP16 with Tensor Core)
  • Memory bandwidth: 2 TB/s (HBM2e)
  • Compute-to-bandwidth ratio: 156:1

This means for most operators (especially element-wise operations), the bottleneck is not computation but data movement. If an operator needs to read 12 MB from HBM (High Bandwidth Memory) and write back 12 MB, even if the computation takes only 1 microsecond, memory I/O alone requires about 12 microseconds (24 MB / 2 TB/s).

Roofline Model: Visualizing the Bottleneck

The Roofline Model (Williams et al., 2009) provides an intuitive framework. It defines Arithmetic Intensity (AI) as:

$$\text{AI} = \frac{\text{FLOPs}}{\text{Memory I/O (Bytes)}}$$

For an element-wise ReLU operator (y = max(0, x)):

  • Each element: 1 comparison + 1 select ≈ 0 FLOPs (the arithmetic is negligible next to the memory access)
  • Each element: read 4 bytes (FP32) + write 4 bytes = 8 bytes

Thus AI ≈ 0. Such operators are entirely memory-bound, leaving GPU compute units mostly idle.
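
The same comparison can be written as a quick back-of-the-envelope check. The sketch below plugs in the A100 numbers from above; the function name and the one-op-per-element count for ReLU are illustrative.

def is_memory_bound(flops, bytes_moved, peak_flops=312e12, bandwidth=2e12):
    """Roofline check using the A100 numbers above (FP16 Tensor Core, HBM2e)."""
    ai = flops / bytes_moved                  # arithmetic intensity (FLOP/byte)
    machine_balance = peak_flops / bandwidth  # 156 FLOP/byte for the A100
    return ai < machine_balance

# ReLU over 3M FP32 elements: at most 1 op and 8 bytes of traffic per element.
n = 3 * 1024 * 1024
print(is_memory_bound(flops=n, bytes_moved=8 * n))  # True: AI ≈ 0.125 << 156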

Essence of fusion: Merge multiple operators into a single kernel, eliminating intermediate HBM round-trips, dramatically reducing memory traffic, increasing AI, and better utilizing hardware compute.

Fusion Taxonomy: Five Major Patterns

Based on fusion patterns and implementation complexity, operator fusion can be categorized into Simple Fusion and Complex Fusion. The interactive demo below showcases five typical fusion types.

(Interactive taxonomy demo: Simple Fusion — Element-wise, Reduction, Broadcast, Transpose/Reshape Elimination; Complex Fusion — FlashAttention (algorithmic rewrite). The element-wise example collapses 3 kernel launches and 72 MB of HBM traffic into 1 launch and 24 MB, a 67% saving.)

1. Element-wise Fusion

Pattern: Fuse multiple element-wise operations (relu, add, mul, tanh, etc.) into a single kernel.

Example: y = relu(x) * alpha + beta

Before optimization:

  1. t1 = relu(x) — 12 MB read + 12 MB write
  2. t2 = t1 * alpha — 12 MB read + 12 MB write
  3. y = t2 + beta — 12 MB read + 12 MB write
  • Total: 72 MB HBM I/O, 3 kernel launches

After optimization:

  • Single kernel computes y = relu(x) * alpha + beta directly
  • Total: 24 MB HBM I/O (12 MB read + 12 MB write)
  • Savings: 67% memory traffic

Implementation key points:

  • Each thread handles one element, performing all computations in registers without writing intermediate results back to HBM.
  • Heavily used by both TorchInductor and XLA.
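
At the framework level this is exactly what torch.compile exercises. The snippet below is a minimal usage sketch: for a pointwise chain like this, TorchInductor typically emits a single fused Triton kernel (a hand-written version of that kernel appears later in this article).

import torch

def chain(x, alpha, beta):
    return torch.relu(x) * alpha + beta          # three pointwise ops

compiled = torch.compile(chain)                  # TorchInductor handles the fusion
x = torch.randn(3 * 1024 * 1024, device="cuda")  # ~12 MB of FP32
y = compiled(x, 0.5, 1.0)                        # one kernel launch instead of three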

2. Reduction Fusion

Pattern: Fuse reduction operators (sum, max, mean) with surrounding element-wise operations.

Example: L2 norm computation norm = sqrt(sum(x**2))

Before optimization:

  1. t1 = x**2 — 12 MB read + 12 MB write
  2. t2 = sum(t1) — 12 MB read + 4 bytes write
  3. norm = sqrt(t2) — 4 bytes read + 4 bytes write

After optimization:

  • Single-pass computation of L2 norm, writing only scalar result
  • Savings: 67% memory traffic

Implementation key points:

  • Leverage warp-level or block-level reduce primitives (e.g., CUDA’s __shfl_down_sync for warp reduction, __syncthreads plus shared memory for block reduction).
  • Avoid materializing intermediate tensors.
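
As a concrete sketch, a single-pass fused L2 norm can be written in Triton roughly as follows. This simplified version launches a single program (grid of (1,)) that walks the whole vector; a production kernel would use a multi-block reduction, and the kernel name is illustrative.

import triton
import triton.language as tl

@triton.jit
def fused_l2_norm_kernel(x_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    # One program iterates over the vector; partial sums stay in registers.
    acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for start in range(0, N, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < N
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
        acc += x * x                    # x**2 fused into the reduction
    total = tl.sum(acc, axis=0)         # block-level reduce
    tl.store(out_ptr, tl.sqrt(total))   # sqrt fused; only a scalar reaches HBM

Only the scalar norm is written back, instead of a 12 MB squared intermediate.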

3. Broadcast Fusion

Pattern: Fuse reduction + broadcast + element-wise operations; the core pattern is reduce-then-apply.

Example: LayerNorm centering step y = x - mean(x)

Before optimization:

  1. mean_val = reduce_mean(x) — 12 MB read + 4 bytes write
  2. mean_broadcast = broadcast(mean_val) — creates 12 MB temporary tensor
  3. y = x - mean_broadcast — 24 MB read + 12 MB write

After optimization:

  • Each thread first computes global mean (shared via shared memory), then directly computes x[i] - mean
  • Savings: 60% memory traffic

Implementation key points:

  • On-the-fly broadcasting: mean stored in register or shared memory, reused by each thread.
  • Core optimization for LayerNorm, BatchNorm, RMSNorm, and other normalization operators.
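
A minimal Triton sketch of the reduce-then-apply pattern for the centering step, assuming one program per row and a tile wide enough to hold an entire row (BLOCK_N ≥ N); the names are illustrative:

import triton
import triton.language as tl

@triton.jit
def fused_center_kernel(x_ptr, y_ptr, N, stride, BLOCK_N: tl.constexpr):
    # One program per row: reduce to the mean, then broadcast it on the fly.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    x = tl.load(x_ptr + row * stride + cols, mask=mask, other=0.0)
    mean = tl.sum(x, axis=0) / N    # reduction result stays in a register
    y = x - mean                    # broadcast without a materialized temporary
    tl.store(y_ptr + row * stride + cols, y, mask=mask)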

4. Transpose/Reshape Elimination

Pattern: Eliminate explicit transpose or reshape operators via stride manipulation.

Example: y = matmul(reshape(transpose(x)), W)

Before optimization:

  1. t1 = transpose(x) — must materialize transposed result (12 MB)
  2. t2 = reshape(t1) — may require copy (12 MB)
  3. y = matmul(t2, W) — read 24 MB + write 12 MB

After optimization:

  • matmul kernel directly reads x following transpose + reshape strides
  • Savings: 71% memory traffic

Implementation key points:

  • Modern CUDA kernels (e.g., CUTLASS) support arbitrary stride input reads.
  • Compiler propagates layout information to consumers via stride propagation.
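
At the framework level the same idea is visible through PyTorch’s stride mechanism: a transpose is only a metadata change, and a stride-aware consumer can read the view directly. A small illustration (the copy at the end is exactly the materialization fusion avoids):

import torch

x = torch.randn(128, 768)
xt = x.t()                              # transpose = swapped strides, no data movement
print(x.stride(), xt.stride())          # (768, 1) vs (1, 768)
print(xt.data_ptr() == x.data_ptr())    # True: same storage, different view

xc = xt.contiguous()                    # the explicit copy a stride-aware kernel avoids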

5. FlashAttention: Algorithmic Rewrite

The first four are pattern fusion, while FlashAttention (Dao et al., 2022) is algorithmic fusion, requiring computation reordering.

Standard Attention computation:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

Problem: The intermediate matrix $S = QK^T \in \mathbb{R}^{N \times N}$ consumes significant memory.

  • For $N = 4096$ in FP16, $S$ requires 32 MB
  • Often the GPU memory bottleneck in Transformers

FlashAttention strategy:

  1. Tiling: Load $Q$, $K$, $V$ blocks into SRAM (shared memory / L2 cache)
  2. Online Softmax: Use the incremental softmax trick to complete softmax within each tile, avoiding materialization of the full $S$
  3. I/O Complexity: Reduced from $O(N^2)$ to $O\!\left(\frac{N^2 d^2}{M}\right)$, where $M$ is the SRAM size

Performance gains: For long sequences ($N > 2048$), FlashAttention is 2-4× faster than the standard implementation, with 10-20× lower memory footprint.
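
The core of step 2 is the online softmax recurrence. The NumPy sketch below processes the scores for a single query row tile by tile, rescaling earlier partial results whenever a new maximum appears; it illustrates the trick only and is not FlashAttention’s actual kernel.

import numpy as np

def online_softmax_matvec(score_tiles, value_tiles):
    # Running max, running normalizer, running weighted sum of values.
    m, l, acc = -np.inf, 0.0, 0.0
    for s, v in zip(score_tiles, value_tiles):   # s: (Bk,), v: (Bk, d)
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)                # rescale earlier contributions
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ v
        m = m_new
    return acc / l                               # equals softmax(scores) @ V

Splitting a row of scores into tiles and feeding them through this loop gives the same result as materializing the full row and applying softmax once.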

Fusion Legality Analysis: Five Rules

Not all operator pairs can be fused. Compilers must check these five conditions:

(Interactive legality checker over an example MLP block: x[128,768] → x@W₁[128,3072] → GELU → GELU@W₂[128,768] → dropout → +residual → LayerNorm → output[128,768]; click any two nodes to check whether they can fuse.)

1. Producer-Consumer Relationship

Rule: Only adjacent operator pairs can be fused, where one operator’s output is another’s input.

Counter-example: two independent matmuls mm1 and mm2 have no producer-consumer edge between them and cannot fuse.

2. No Cycle

Rule: Fusion must not introduce cycles that would break the DAG (Directed Acyclic Graph) structure of the computation graph.

Detection method: DFS or topological sort.
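
A minimal DFS sketch of that check, assuming a hypothetical graph.users(node) accessor that returns a node’s direct consumers:

def fusion_creates_cycle(graph, producer, consumer):
    # Fusing is illegal if consumer is reachable from producer through
    # any path other than the direct producer -> consumer edge.
    stack = [u for u in graph.users(producer) if u is not consumer]
    seen = set()
    while stack:
        node = stack.pop()
        if node is consumer:
            return True          # indirect path found: fusing would form a cycle
        if node in seen:
            continue
        seen.add(node)
        stack.extend(graph.users(node))
    return False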

3. Shape Compatibility

Rule: Fused operators must share the same iteration domain.

Compatible scenarios:

  • Element-wise ops: identical shapes
  • Broadcast ops: alignable via broadcasting
  • Reduction ops: computable via tiling

Incompatible scenarios:

  • matmul output shape [M, N] vs. layer_norm input shape [B, S, D] cannot directly align (requires reshape)

4. No Side Effects

Rule: Fused operators must have no observable side effects.

Common side-effect operators:

  • dropout: depends on random number generator state
  • In-place operations: modify input tensors
  • I/O operations: print, log, checkpoint

Handling strategy:

  • In inference mode, dropout degenerates to identity and can fuse
  • Compiler must mark side effects to prevent cross-boundary fusion

5. Memory Fits in SRAM

Rule: Fused kernel’s total intermediate results must fit in SRAM (CUDA’s shared memory + register file).

A100 SRAM limits:

  • Shared memory per SM: 164 KB (dynamically configurable)
  • Register file per SM: 256 KB
  • Practically usable: ~48 KB (after accounting for occupancy and bank conflicts)

Examples:

  • mm1 (12 KB) + gelu (2 KB) = 14 KB ✅
  • mm1 (12 KB) + mm2 (12 KB) = 24 KB ⚠️ (requires tiling)
  • mm1 (32 KB) + softmax (32 KB) = 64 KB ❌ (exceeds limit)
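
Putting the five rules together, a compiler’s legality check looks roughly like the sketch below. All helpers and node attributes (graph.users, compatible_iteration_domains, has_side_effects, workspace_bytes) are illustrative stand-ins, and this is roughly the shape of the predicate that the greedy algorithm in the next section relies on.

def can_fuse(producer, consumer, graph, sram_budget=48 * 1024):
    # 1. Producer-consumer: must be directly connected.
    if consumer not in graph.users(producer):
        return False
    # 2. No cycle: reuse the DFS check from Rule 2.
    if fusion_creates_cycle(graph, producer, consumer):
        return False
    # 3. Shape compatibility: same (or broadcastable) iteration domain.
    if not compatible_iteration_domains(producer, consumer):
        return False
    # 4. No side effects (RNG state, in-place mutation, I/O).
    if producer.has_side_effects or consumer.has_side_effects:
        return False
    # 5. Intermediate working set must fit in SRAM.
    return producer.workspace_bytes + consumer.workspace_bytes <= sram_budget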

Fusion Decision Algorithms: Greedy vs. Graph Coloring

Compilers must automatically decide which operators to fuse. Two mainstream approaches:

1. Greedy Fusion — TorchInductor

Strategy: Traverse the computation graph in topological order, immediately fusing fusible edges.

Pseudocode:

def greedy_fusion(graph):
    # Start with every node in its own fusion group.
    groups = {node: {node} for node in graph.nodes}
    # Visit edges following the producers' topological order.
    for producer, consumer in topological_order(graph.edges):
        if can_fuse(producer, consumer):
            merged = groups[producer] | groups[consumer]
            # Point every member at the merged group so later edges see it.
            for node in merged:
                groups[node] = merged
    return groups

Advantages:

  • Simple and fast, O(E)O(E) complexity
  • Easy to implement and debug

Disadvantages:

  • Locally optimal, may miss globally better fusion plans
  • Sensitive to graph traversal order
(Interactive demo: greedy fusion stepping through x [128×768] → LayerNorm → x@W₁ → GELU → GELU@W₂ → +residual → output; initially each node is its own group.)

2. Graph Coloring — XLA

Strategy: Model fusion as a graph coloring problem:

  • Each operator is a node
  • Two operators that cannot fuse are connected by an edge
  • Goal: color all nodes with minimum colors (same color = same fusion group)

Optimization objectives:

  • Minimize memory traffic (estimated via heuristic functions)
  • Minimize kernel launch count (number of colors)
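
A toy version of the coloring step, assuming a precomputed conflicts map where conflicts[op] is the set of operators that must not share a kernel with op (an illustration of the idea, not XLA’s actual pass):

def color_fusion_groups(ops, conflicts):
    color = {}
    for op in ops:                          # in practice ordered by degree or cost
        used = {color[n] for n in conflicts[op] if n in color}
        c = 0
        while c in used:                    # smallest color unused by any neighbor
            c += 1
        color[op] = c                       # same color = same fusion group
    return color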

Advantages:

  • Can find globally better solutions (graph coloring is NP-hard, so heuristic/approximate solvers are used in practice)
  • Supports complex constraints (e.g., memory budget, occupancy)

Disadvantages:

  • Longer compile time (O(V2)O(V^2) to O(V3)O(V^3))
  • Complex implementation, difficult to debug

TorchInductor vs. XLA

Dimension | TorchInductor | XLA
Algorithm | Greedy | Graph Coloring
Compile Speed | Fast (< 100 ms) | Slow (> 1 s)
Fusion Quality | Good (covers ~90% of common patterns) | Excellent (but limited improvement on rare patterns)
Use Case | Eager mode, JIT | AOT, static graph

PyTorch 2 prioritizes compile speed because in Eager execution mode, compilation overhead directly impacts user experience. XLA targets AOT compilation and can afford longer compile times.

TorchInductor Fusion Implementation Details

Lowering Fusion Groups

Fused groups generate single kernels via Triton IR:

# Fusion group: relu → mul → add
@triton.jit
def fused_kernel(x_ptr, out_ptr, alpha, beta, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    
    # Load
    x = tl.load(x_ptr + offsets, mask=mask)
    
    # Fused computation (in registers)
    t1 = tl.maximum(x, 0.0)  # relu
    t2 = t1 * alpha          # mul
    out = t2 + beta          # add
    
    # Store
    tl.store(out_ptr + offsets, out, mask=mask)

Key points:

  • All intermediate results (t1, t2) stored in registers, no HBM writeback
  • Triton’s tl.load / tl.store automatically handle memory alignment and coalescing
  • BLOCK_SIZE is compile-time constant, Triton auto-unrolls loops

Scheduler: Determining Kernel Launch Order

Fused graphs may contain multiple independent fusion groups. Scheduler handles:

  1. Topological Sort: Ensure correct dependencies
  2. Memory Planning: Decide when to allocate/free intermediate buffers
  3. Concurrent Execution: Leverage CUDA streams for parallelism
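
A toy sketch of the first two responsibilities — launching groups in topological order and freeing each intermediate buffer right after its last consumer — with launch, free, and the node attributes as hypothetical placeholders:

def run_schedule(groups_in_topo_order, last_use):
    live = set()
    for step, group in enumerate(groups_in_topo_order):
        live |= set(group.outputs)          # allocate outputs before launch
        launch(group)                       # hypothetical kernel launch
        for buf in list(live):
            if last_use[buf] <= step:       # no later consumer remains
                free(buf)                   # return the buffer to the pool
                live.discard(buf)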

TorchInductor uses a dynamic scheduler that adjusts at runtime based on actual tensor shapes.

Summary: Fusion is the Silver Bullet for Memory-Bound Problems

This article systematically introduced operator fusion’s taxonomy, legality analysis, and decision algorithms. Key takeaways:

  1. Why fusion matters most: Modern GPUs have compute-to-bandwidth ratios of 156:1; most operators are memory-bound
  2. Five fusion patterns: Element-wise, Reduction, Broadcast, Transpose, FlashAttention
  3. Five legality rules: Producer-consumer, No cycle, Shape compatible, No side effect, Memory fits
  4. Two decision algorithms: Greedy (TorchInductor, fast), Graph Coloring (XLA, higher fusion quality but slower to compile)

The next article will dive into Cost Models, exploring how compilers quantify fusion benefits and make optimal decisions in complex scenarios.

Further Reading

  • FlashAttention principles: Dao et al. (2022) paper details tiling and online softmax algorithms
  • XLA Fusion implementation: TensorFlow XLA’s HLO Fusion Pass source code (xla/service/gpu/gpu_fusible.cc)
  • Triton language: OpenAI’s Triton provides easier kernel authoring than CUDA
  • Roofline Model tools: NERSC’s Roofline Toolkit automatically analyzes operator AI and bottlenecks