Operator Fusion (Part II): Cost Models & Fusion in Practice
Updated 2026-04-13
Introduction
The previous article covered the WHAT (five fusion types) and WHEN (legality decision algorithms) of operator fusion. This article addresses two more critical questions:
- WHETHER — Is fusion beneficial? Not every legal fusion is worth performing.
- HOW — How is efficient fusion implemented in practice? From TorchInductor’s heuristics to FlashAttention’s algorithmic rewrite.
The core message: “can fuse” does not mean “should fuse”. Blindly fusing can increase register pressure, reduce occupancy, and cause compilation time to explode. Mature compilers need a cost model to make informed fusion decisions.
Cost Model Design
Why Not “Fuse Everything”?
Suppose we have a legal fusion candidate A+B. From the previous article, we know it passes all legality checks. But the fused kernel may actually be slower than running two separate kernels. Three reasons:
1. Register Pressure
Each GPU thread has a limited number of registers. An NVIDIA SM (Streaming Multiprocessor) typically has 65,536 32-bit registers. If a kernel requires 32 registers per thread with blockSize 256:

$$\text{blocks per SM} = \left\lfloor \frac{65{,}536}{32 \times 256} \right\rfloor = 8$$

If fusion increases register demand from 32 to 48 registers per thread:

$$\text{blocks per SM} = \left\lfloor \frac{65{,}536}{48 \times 256} \right\rfloor = 5$$

Blocks drop from 8 to 5 (64 resident warps → 40), meaning 37.5% fewer warps can run concurrently on the SM.
2. Occupancy
Occupancy is a key GPU performance metric, defined as:

$$\text{Occupancy} = \frac{\text{active warps per SM}}{\text{maximum warps per SM}}$$
Three constraints affect occupancy:
- Register file size: As shown above, more registers per thread means fewer warps
- Shared memory capacity: V100 = 96 KB, A100 = 164 KB, H100 = 228 KB. If one block needs 32 KB shared memory, A100 can fit at most 5 blocks (164/32 ≈ 5)
- Thread limits: Hardware caps on resident threads per SM (typically 2,048, i.e. 64 warps) and on threads per block (typically 1,024)
Low occupancy means the GPU cannot effectively hide memory latency through warp switching, resulting in actual throughput far below peak.
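To make these limits concrete, here is a back-of-the-envelope occupancy estimator using the register example above and A100-like limits. It is an illustrative sketch only; a real tool such as the CUDA Occupancy Calculator also models per-SM block caps and allocation granularity.

```python
# Rough occupancy estimate from register, shared-memory, and thread limits.
# Hardware numbers are illustrative (A100-like), not authoritative.

def occupancy(regs_per_thread, smem_per_block, threads_per_block,
              regs_per_sm=65_536, smem_per_sm=164 * 1024, max_warps_per_sm=64):
    blocks_by_regs = regs_per_sm // (regs_per_thread * threads_per_block)
    blocks_by_smem = smem_per_sm // smem_per_block if smem_per_block else blocks_by_regs
    blocks = min(blocks_by_regs, blocks_by_smem)
    resident_warps = min(blocks * threads_per_block // 32, max_warps_per_sm)
    return resident_warps / max_warps_per_sm

# The article's example: fusion raises register demand from 32 to 48 per thread.
print(occupancy(32, 0, 256))   # 1.000 -> 8 blocks, 64 resident warps
print(occupancy(48, 0, 256))   # 0.625 -> 5 blocks, 40 resident warps
```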
3. Compilation Time
Triton and CUDA kernel compilation time grows roughly quadratically with kernel size. A fused large kernel might go from milliseconds to seconds of compilation time, significantly impacting JIT compilation scenarios (e.g., torch.compile).
Formalizing the Roofline Model
The core of the cost model is the Roofline Model (Williams et al., 2009). For a kernel:

$$T_{\text{kernel}} \approx \frac{\max\!\left(\dfrac{\text{FLOPs}}{\pi_{\text{peak}}},\ \dfrac{\text{HBM bytes}}{\beta_{\text{mem}}}\right)}{f(\text{occupancy})}$$

where:

- $\pi_{\text{peak}}$ is peak compute throughput and $\beta_{\text{mem}}$ is HBM bandwidth
- $f(\text{occupancy})$ is a monotonically increasing function reflecting occupancy's impact on actual throughput. A common approximation treats throughput as saturated above roughly 25% occupancy; below that threshold, throughput does not drop linearly, because instruction-level parallelism still provides some utilization.
The interactive component below lets you adjust hardware parameters and observe the effects of different fusion decisions.
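The same reasoning also fits in a few lines of Python. In the sketch below, the hardware constants (roughly A100-class) and the shape of the occupancy penalty are illustrative assumptions, not measurements.

```python
# Roofline-style comparison of two separate kernels vs. one fused kernel.
# Hardware constants are rough A100-class assumptions, not measurements.

PEAK_FLOPS = 312e12   # FP16 Tensor Core peak, FLOP/s
HBM_BW = 1.5e12       # HBM bandwidth, bytes/s

def occupancy_factor(occ, knee=0.25):
    # Assumed shape: flat above ~25% occupancy, sub-linear fall-off below it
    # (instruction-level parallelism still hides some latency).
    return 1.0 if occ >= knee else (occ / knee) ** 0.5

def kernel_time(flops, hbm_bytes, occ=1.0):
    roofline = max(flops / PEAK_FLOPS, hbm_bytes / HBM_BW)
    return roofline / occupancy_factor(occ)

# Two chained pointwise ops over a 1 GiB tensor: unfused, the intermediate
# makes a full HBM round trip; fused, it stays in registers but register
# pressure pushes occupancy down (the occupancy values are made up).
gib = 2**30
unfused = 2 * kernel_time(flops=gib, hbm_bytes=2 * gib, occ=0.50)
fused = kernel_time(flops=2 * gib, hbm_bytes=2 * gib, occ=0.15)
print(f"unfused ≈ {unfused * 1e3:.2f} ms   fused ≈ {fused * 1e3:.2f} ms")
```

Here fusion still wins because the workload is memory-bound, but the occupancy hit eats part of the gain; with a larger register penalty the comparison can flip.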
TorchInductor’s Cost Model in Practice
PyTorch 2’s TorchInductor (Ansel et al., 2024) uses a heuristic cost model — not exact modeling, but experience-based rules.
Core Fusion Heuristics
Pointwise fusion (almost always applied):
Inductor fuses consecutive pointwise operations (element-wise, broadcast) by default. The reasoning is simple: pointwise ops need no shared memory, register overhead is typically small, and eliminating intermediate tensor HBM round-trips is almost always a net positive.
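A minimal illustration, assuming a CUDA device and a recent PyTorch 2.x (the actual fusion outcome can be verified with the trace flags shown below under Debugging Fusion Decisions):

```python
import torch

def pointwise_chain(x, bias):
    # relu, mul, add: three element-wise ops with no data-dependent control flow.
    # Unfused, each would produce an intermediate tensor that round-trips through
    # HBM; Inductor typically lowers the whole chain to a single Triton kernel.
    return torch.relu(x) * 2.0 + bias

compiled = torch.compile(pointwise_chain)
x = torch.randn(4096, 4096, device="cuda")
bias = torch.randn(4096, device="cuda")
out = compiled(x, bias)
```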
Reduction + Pointwise fusion (conditional):
For a reduction (e.g., sum, max) followed by pointwise operations, Inductor checks:
- Whether the reduction dimension is small enough (won’t overflow shared memory)
- Whether post-fusion register usage stays manageable
If the reduction spans a large dimension (e.g., hidden_dim reduction over [batch, seq_len, hidden_dim]), fusion may cause excessive shared memory demand.
MatMul + Epilogue fusion (high value):
GEMM epilogue fusion is one of the most valuable fusion patterns. Before a GEMM tile writes back to HBM, bias add, activation (ReLU/GELU), and dropout are computed directly in registers. Both cuBLAS and CUTLASS natively support epilogue fusion; Triton achieves it through code generation.
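Conceptually, the epilogue runs on each output tile while it still lives in registers, so the pre-activation matrix never touches HBM. The NumPy sketch below mirrors that data flow (tiles stand in for register fragments; the real implementations are CUTLASS/Triton templates):

```python
import numpy as np

def gemm_with_fused_epilogue(A, B, bias, tile=64):
    """Tiled GEMM where bias-add + ReLU run on each output tile before write-back."""
    M, K = A.shape
    _, N = B.shape
    C = np.empty((M, N), dtype=A.dtype)
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            acc = A[i:i + tile] @ B[:, j:j + tile]   # accumulator tile ("registers")
            acc += bias[j:j + tile]                  # epilogue: bias add
            np.maximum(acc, 0.0, out=acc)            # epilogue: ReLU
            C[i:i + tile, j:j + tile] = acc          # single write-back to "HBM"
    return C

A, B = np.random.randn(256, 128), np.random.randn(128, 256)
bias = np.random.randn(256)
assert np.allclose(gemm_with_fused_epilogue(A, B, bias),
                   np.maximum(A @ B + bias, 0.0))
```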
Fusion Size Control
# Max nodes per fusion group
torch._inductor.config.max_fusion_size = 64 # default
# Max inputs to a cat that Inductor will lower as a fused pointwise op
torch._inductor.config.max_pointwise_cat_inputs = 8
Debugging Fusion Decisions
Understanding the compiler’s fusion decisions is crucial during development:
import torch
# Method 1: Set trace environment variable
# TORCH_COMPILE_DEBUG=1 python my_script.py
# Method 2: Enable in code
torch._inductor.config.trace.enabled = True
torch._inductor.config.trace.graph_diagram = True # Generate pre/post fusion graphs
# Method 3: Inspect generated Triton kernels
torch._inductor.config.debug = True
# Generated kernel code in /tmp/torchinductor_<user>/
Examining trace output reveals fusion logs like:
[FUSION] fused pointwise: relu + mul + add → fused_kernel_0 (3 nodes)
[FUSION] skipped: layernorm + large_epilogue (register pressure: 52 > threshold 48)
MLIR-Level Fusion
MLIR (Multi-Level Intermediate Representation) provides a more principled approach to fusion than Inductor.
Linalg Dialect Fusion
MLIR’s Linalg dialect represents tensor operations as structured ops, naturally supporting producer-consumer fusion analysis. The core transform operation, transform.structured.fuse_into_containing_op, inlines a producer’s computation into the containing (consumer) loop body.
Tile-and-Fuse: The Core Strategy
MLIR’s fusion strategy follows a key principle: tile first, then fuse.
- Tile the consumer: Split the consumer’s computation into tiles that fit the target memory level (e.g., L1 cache or shared memory)
- Fuse producer into tile: Inline the producer’s computation into the consumer’s tile loop
- Guarantee working set fits: Since tile size is determined by memory capacity, the fused working set naturally won’t overflow
This contrasts sharply with Inductor’s “fuse first, hope it doesn’t overflow” approach. Tile-and-fuse is a constraint-first method — it designs fusion around memory limits from the start, rather than checking them after the fact.
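The scheduling idea is easy to mimic in scalar code. In the NumPy sketch below (purely illustrative, not how MLIR emits code), the producer (a GELU) is recomputed inside each tile of the consumer (a row-sum), so the live intermediate never exceeds one tile:

```python
import numpy as np

def gelu(x):  # producer: a pointwise op
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def fused_rowsum_of_gelu(x, tile_rows=64):
    """Consumer (row-sum) tiled first; producer (gelu) fused into each tile.

    The intermediate gelu(x) only ever exists one (tile_rows, cols) slab at a
    time, so the working set is bounded by the tile size, not the full tensor.
    """
    n_rows = x.shape[0]
    out = np.empty(n_rows, dtype=x.dtype)
    for r0 in range(0, n_rows, tile_rows):
        tile = x[r0:r0 + tile_rows]                      # load one consumer tile
        out[r0:r0 + tile_rows] = gelu(tile).sum(axis=1)  # producer inlined here
    return out

x = np.random.randn(1024, 4096).astype(np.float32)
assert np.allclose(fused_rowsum_of_gelu(x), gelu(x).sum(axis=1), atol=1e-4)
```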
Affine Fusion
For perfect affine loop nests, MLIR’s affine dialect supports loop fusion based on polyhedral analysis. This approach can automatically discover optimal loop fusion orders and tile sizes, but only applies to scenarios with static shapes and affine indices.
Comparison Summary
| Dimension | TorchInductor | MLIR |
|---|---|---|
| Approach | Heuristic rules | Tile-and-fuse + polyhedral |
| Strengths | Fast, practical, covers common patterns | Principled, correctness guarantees |
| Weaknesses | May miss optimizations or make suboptimal decisions | Higher compile overhead, limited dynamic shape support |
| Best for | JIT compilation (torch.compile) | AOT compilation (deployment optimization) |
FlashAttention Deep Dive
FlashAttention (Dao et al., 2022) is one of the most impactful optimizations in ML systems. It is not generic fusion — it is a domain-specific algorithmic rewrite.
Standard Attention’s Memory Bottleneck
Standard self-attention computation flow:

$$S = \frac{QK^\top}{\sqrt{d}}, \qquad P = \mathrm{softmax}(S), \qquad O = PV$$

where $Q, K, V \in \mathbb{R}^{N \times d}$ ($N$ = sequence length, $d$ = head dimension).

The critical bottleneck is $S$ — an $N \times N$ matrix. For example, with $N = 4096$, $d = 64$, FP16, $S$ alone occupies $4096 \times 4096 \times 2$ bytes = 32 MB per head.

Total HBM access for standard attention (per head, using the example numbers above):
- Read $Q$: $Nd \times 2$ bytes = 0.5 MB
- Read $K$: 0.5 MB
- Write $S$: $N^2 \times 2$ bytes = 32 MB → HBM
- Read $S$: 32 MB ← HBM (for softmax)
- Write $\mathrm{softmax}(S)$: 32 MB → HBM
- Read $\mathrm{softmax}(S)$: 32 MB ← HBM (multiply by $V$)
- Read $V$: 0.5 MB
- Write $O$: 0.5 MB

Total HBM access: ≈ 130 MB, with 128 MB spent on the score matrix alone.

I/O complexity: $\Theta(N^2 + Nd)$. When $N \gg d$ (the usual case), the $N^2$ term dominates.
FlashAttention’s Core Idea
FlashAttention’s key insight: the full $N \times N$ matrix $S$ never needs to be materialized.

Tile $Q$, $K$, $V$ into blocks of $B_r$ (for $Q$) and $B_c$ (for $K$, $V$) rows, with

$$B_r \approx B_c \approx \frac{M}{4\,d\,e}$$

where $M$ is SRAM (shared memory) size in bytes, $d$ is head dimension, and $e$ is bytes per element ($e = 2$ for FP16).
Algorithm outline:
For each Q tile (Br rows):
    Load Q_tile from HBM to SRAM
    Initialize output accumulator O_tile = 0, running max m = -inf, running sum l = 0
    For each K,V tile (Bc rows):
        Load K_tile, V_tile from HBM to SRAM
        Compute S_tile = Q_tile × K_tile^T in SRAM (size Br × Bc, fits entirely in SRAM!)
        Compute local softmax statistics: m_new, l_new, P_tile
        Update O_tile with online softmax rescaling
    Write O_tile back to HBM
At each iteration, only a $B_r \times B_c$ sub-matrix of $S$ exists in SRAM — never the full $N \times N$ matrix.
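To confirm that this loop structure computes exact attention rather than an approximation, here is a compact NumPy rendition of the same outline (single head, no masking; a didactic sketch with none of the real kernel's performance characteristics):

```python
import numpy as np

def flash_attention_ref(Q, K, V, Br=64, Bc=64):
    """Tiled attention with online-softmax rescaling (NumPy illustration)."""
    N, d = Q.shape
    O = np.zeros((N, d))
    scale = 1.0 / np.sqrt(d)
    for i in range(0, N, Br):                      # outer loop over Q tiles
        Qi = Q[i:i + Br]
        Oi = np.zeros((Qi.shape[0], d))
        m = np.full(Qi.shape[0], -np.inf)          # running row max
        l = np.zeros(Qi.shape[0])                  # running row sum
        for j in range(0, N, Bc):                  # inner loop over K/V tiles
            Kj, Vj = K[j:j + Bc], V[j:j + Bc]
            S = (Qi @ Kj.T) * scale                # Br x Bc block, never N x N
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])
            correction = np.exp(m - m_new)         # rescale previous partials
            l = l * correction + P.sum(axis=1)
            Oi = Oi * correction[:, None] + P @ Vj
            m = m_new
        O[i:i + Br] = Oi / l[:, None]
    return O

def standard_attention(Q, K, V):
    S = (Q @ K.T) / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
assert np.allclose(flash_attention_ref(Q, K, V), standard_attention(Q, K, V))
```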
Online Softmax: The Key Trick
Softmax is fundamentally a global operation — it needs the maximum over an entire row. FlashAttention uses online softmax (Milakov & Gimelshein, 2018) to solve this:
For the softmax of a vector $x \in \mathbb{R}^N$, where $\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}$ with $m = \max_j x_j$, the computation can be tiled over blocks $x^{(1)}, x^{(2)}, \dots$:

- Process block 1: $m^{(1)} = \max\big(x^{(1)}\big)$, $\;\ell^{(1)} = \sum_i e^{x^{(1)}_i - m^{(1)}}$
- Process block 2: $m^{(2)} = \max\big(m^{(1)}, \max(x^{(2)})\big)$, $\;\ell^{(2)} = e^{m^{(1)} - m^{(2)}}\,\ell^{(1)} + \sum_i e^{x^{(2)}_i - m^{(2)}}$
- Continue, rescaling previous accumulations with each new max value

This computes the correct softmax without ever storing the full $N$-dimensional vector.
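The recurrence translates almost line for line into code. A toy NumPy sketch (the final normalization below re-reads x for clarity; FlashAttention instead applies the equivalent rescaling to its output accumulator):

```python
import numpy as np

def online_softmax(x, block=4):
    """Softmax of a 1-D vector computed one block at a time."""
    m, l = -np.inf, 0.0                 # running max and running exp-sum
    for start in range(0, len(x), block):
        xb = x[start:start + block]
        m_new = max(m, xb.max())
        l = l * np.exp(m - m_new) + np.exp(xb - m_new).sum()  # rescale old sum
        m = m_new
    # A second tiled pass (or, in FlashAttention, a final rescale of O) uses m and l:
    return np.exp(x - m) / l

x = np.random.randn(10)
assert np.allclose(online_softmax(x), np.exp(x - x.max()) / np.exp(x - x.max()).sum())
```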
I/O Complexity Analysis
FlashAttention’s HBM access:
$T_r = \lceil N / B_r \rceil$ outer-loop iterations, each:
- Load one $Q$ tile: $B_r\,d \cdot 2$ bytes
- $T_c = \lceil N / B_c \rceil$ inner-loop iterations, each loading one $K$ tile + one $V$ tile: $2\,B_c\,d \cdot 2$ bytes
- Write one $O$ tile: $B_r\,d \cdot 2$ bytes

Total HBM access:

$$T_r \cdot 2 B_r d \cdot 2 \;+\; T_r T_c \cdot 2 B_c d \cdot 2 \;=\; 4 N d + \frac{4 N^2 d}{B_r}\ \text{bytes}$$

Since $B_r = \Theta(M / d)$:

$$\text{HBM access} = \Theta\!\left(N d + \frac{N^2 d^2}{M}\right)$$

When $M \gtrsim N d$ (SRAM is large enough), this simplifies to $\Theta(N d)$ — linear in sequence length $N$!

Compared to standard attention’s $\Theta(N^2 + Nd)$, FlashAttention provides order-of-magnitude I/O efficiency gains on long sequences.
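Plugging these formulas into a few lines of Python makes the gap concrete. The sketch below uses the tile-size rule from earlier and counts per-head FP16 traffic; the SRAM size (164 KB, A100-like) and the constant factors are illustrative assumptions:

```python
import math

def standard_attention_io(N, d, e=2):
    # Q, K, V, O each touched once; S and softmax(S) each written then re-read.
    return (4 * N * d + 4 * N * N) * e

def flash_attention_io(N, d, M=164 * 1024, e=2):
    B = max(1, M // (4 * d * e))          # Br = Bc ≈ M / (4·d·e)
    Tr = Tc = math.ceil(N / B)
    q_o_traffic = Tr * 2 * B * d * e      # load Q tile + write O tile per outer step
    kv_traffic = Tr * Tc * 2 * B * d * e  # K and V tiles reloaded every outer step
    return q_o_traffic + kv_traffic

for N in (1024, 4096, 16384):
    std, fl = standard_attention_io(N, 64), flash_attention_io(N, 64)
    print(f"N={N:6d}  standard ≈ {std / 2**20:8.1f} MB   flash ≈ {fl / 2**20:7.1f} MB")
```

With these assumptions the gap grows from roughly 5x at N=1024 to about 10x at N=16384, matching the qualitative claim above.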
FlashAttention-2: Fewer Non-MatMul FLOPs
FlashAttention-2 (Dao, 2023) key improvements:
- Move rescaling out of the inner loop: Reduces non-matmul FLOPs. On GPUs, Tensor Cores only accelerate matmul; other operations use CUDA cores, which are much slower. FA-1 rescales the output accumulator at every inner step, producing many non-matmul FLOPs. FA-2 defers the final rescaling to the end of the outer loop.
- Better warp partitioning: FA-1 partitions warps across the K/V (sequence) dimension (each warp computes a partial result for the same Q rows), requiring cross-warp reduction through shared memory. FA-2 partitions across the Q dimension (each warp owns its own block of Q rows while K/V are shared), eliminating cross-warp communication.
Result: FA-2 achieves 50-73% of theoretical FLOPs utilization on A100, versus 25-40% for FA-1.
FlashAttention-3: Hopper Architecture Specialization
FlashAttention-3 (Shah et al., 2024) is deeply optimized for NVIDIA Hopper (H100):
- Warp-specialized pipeline: Exploits Hopper’s async execution — producer warps handle HBM→SRAM data movement while consumer warps compute, running as a pipeline.
- FP8 support: Hopper’s FP8 Tensor Cores deliver 2x FP16 throughput. FA-3 supports FP8 with incoherent processing to maintain numerical accuracy.
- Block quantization: Quantizes softmax intermediate results to FP8, reducing register and shared memory footprint.
Why FlashAttention is NOT “Operator Fusion”
It bears emphasizing: FlashAttention is not something a generic compiler can automatically discover. It requires:
- Understanding the mathematical properties of softmax (can be computed online)
- Designing a new algorithm (tiled attention with online softmax rescaling)
- Hand-optimizing memory access patterns
No generic pattern matching or cost model can automatically derive this algorithm from the operator graph. This is a domain-specific algorithmic rewrite — the intersection of compiler technology and domain knowledge.
Fusion Benchmarks in Practice
After the theoretical analysis, let us examine real-world data. The benchmark comparison below shows different fusion strategies across typical Transformer configurations (approximate values for educational purposes).
The four strategy levels:
- No Fusion: Each operator as a separate kernel, baseline
- Element-wise Only: Only pointwise fusion (GELU+dropout, bias+add, etc.)
- Full Inductor: TorchInductor’s complete fusion (including reduction fusion, epilogue fusion)
- Inductor + FlashAttn: Full Inductor plus FlashAttention
Analysis
Several key observations:
Universal benefit of element-wise fusion: Across all model sizes, element-wise fusion alone yields 1.8-2.0x throughput improvement. This is the “low-hanging fruit” — simple, safe, and with virtually no downside.
Value of Full Inductor: On top of element-wise fusion, Inductor’s reduction fusion and epilogue fusion contribute an additional 1.5-1.7x improvement. This is where the cost model becomes essential to avoid harmful fusions.
FlashAttention scales with sequence length:
- GPT-2 Small (seq=1024): FlashAttention adds 31% throughput
- LLaMA 7B (seq=2048): Adds 35% throughput
- LLaMA 70B (seq=4096): Adds 41% throughput
Longer sequences mean larger standard attention overhead, making FlashAttention’s linear I/O advantage more pronounced.
Substantial peak memory reduction: Fusion improves not just speed but also memory usage. From No Fusion to Inductor+FlashAttn, peak memory drops 37-47%. This is critical for large model training — saved memory can be used for larger batch sizes or longer sequences.
Summary
Core lessons from this article:
- Cost models are essential. Fusion is not free — it can increase register pressure, reduce occupancy, and lengthen compile times. Mature compilers need quantitative analysis for fusion decisions.
- Heuristic vs. principled approaches. TorchInductor uses fast heuristics (suited for JIT), MLIR uses tile-and-fuse (suited for AOT). They are complementary, not competing.
- Algorithmic rewrites transcend generic fusion. FlashAttention demonstrates the power of domain knowledge — by understanding attention’s mathematical structure, it achieves optimizations no compiler could automatically discover. Future optimal systems will combine generic fusion (compiler) with domain-specific rewrites (libraries/algorithms).
- The performance optimization hierarchy: Element-wise fusion (basic) → cost-model-guided fusion (intermediate) → algorithmic rewrite (expert). Each level provides irreplaceable value.
Further Reading
- FlashAttention paper trilogy: FA-1, FA-2, FA-3. Read sequentially to understand the evolution from core idea to Hopper specialization.
- Roofline Model original paper: Williams, Waterman & Patterson (2009), “Roofline: An Insightful Visual Performance Model for Multicore Architectures”. The classic framework for understanding compute-bound vs memory-bound.
- MLIR Linalg fusion docs: Official documentation. Implementation details for tile-and-fuse.
- PyTorch 2 paper: Ansel et al. (2024), “PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation and Graph Compilation”, ASPLOS 2024. Engineering implementation of TorchInductor’s fusion strategies.
- CUTLASS Epilogue Fusion: NVIDIA’s CUTLASS library provides templatized GEMM epilogue fusion — the best reference for understanding matmul+activation fusion.