
Operator Fusion (Part II): Cost Models & Fusion in Practice

Updated 2026-04-13


Introduction

The previous article covered the WHAT (five fusion types) and WHEN (legality decision algorithms) of operator fusion. This article addresses two more critical questions:

  1. WHETHER — Is fusion beneficial? Not every legal fusion is worth performing.
  2. HOW — How is efficient fusion implemented in practice? From TorchInductor’s heuristics to FlashAttention’s algorithmic rewrite.

The core message: “can fuse” does not mean “should fuse”. Blindly fusing can increase register pressure, reduce occupancy, and cause compilation time to explode. Mature compilers need a cost model to make informed fusion decisions.

Cost Model Design

Why Not “Fuse Everything”?

Suppose we have a legal fusion candidate A+B. From the previous article, we know it passes all legality checks. But the fused kernel may actually be slower than running two separate kernels. Three reasons:

1. Register Pressure

Each GPU thread has a limited number of registers. An NVIDIA SM (Streaming Multiprocessor) typically has 65,536 32-bit registers. If a kernel requires 32 registers per thread with blockSize 256:

Registers per block = 32 × 256 = 8192
Max blocks per SM = ⌊65536 / 8192⌋ = 8

If fusion increases register demand from 32 to 48:

Registers per block = 48 × 256 = 12288
Max blocks per SM = ⌊65536 / 12288⌋ = 5

Blocks drop from 8 to 5, meaning 37.5% fewer warps can run concurrently on the SM.

2. Occupancy

Occupancy is a key GPU performance metric, defined as:

Occupancy = (active warps per SM) / (max warps per SM)

Three constraints affect occupancy:

  • Register file size: As shown above, more registers per thread means fewer warps
  • Shared memory capacity: V100 = 96 KB, A100 = 164 KB, H100 = 228 KB. If one block needs 32 KB shared memory, A100 can fit at most 5 blocks (164/32 ≈ 5)
  • Max threads per block: Hardware limit (typically 1024)

Low occupancy means the GPU cannot effectively hide memory latency through warp switching, resulting in actual throughput far below peak.
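The arithmetic above can be collected into a small occupancy estimator. This is a simplification (real hardware also rounds register allocation to a fixed granularity and caps the number of resident blocks); the resource limits are A100-class numbers and the function name is ours:

```python
# Illustrative occupancy estimator (simplified). Hardware limits below are
# A100-class: 64K 32-bit registers, 164 KB shared memory, 2048 threads
# (64 warps) per SM.

def max_blocks_per_sm(regs_per_thread, smem_per_block, threads_per_block,
                      reg_file=65536, smem_capacity=164 * 1024,
                      max_threads=2048, max_warps=64):
    """Return (blocks per SM, occupancy) under the three resource limits."""
    by_regs = reg_file // (regs_per_thread * threads_per_block)
    by_smem = smem_capacity // smem_per_block if smem_per_block else float("inf")
    by_threads = max_threads // threads_per_block
    blocks = min(by_regs, by_smem, by_threads)
    warps = blocks * threads_per_block // 32
    return blocks, warps / max_warps

# Unfused kernel: 32 regs/thread, 256 threads/block -> 8 blocks/SM, 100%
print(max_blocks_per_sm(32, 0, 256))
# Fused kernel: 48 regs/thread -> only 5 blocks/SM, 62.5% occupancy
print(max_blocks_per_sm(48, 0, 256))
```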

3. Compilation Time

Triton and CUDA kernel compilation time grows roughly quadratically with kernel size. A large fused kernel might go from millisecond-level to second-level compilation times, significantly impacting JIT compilation scenarios (e.g., torch.compile).

Formalizing the Roofline Model

The core of the cost model is the Roofline Model (Williams et al., 2009). For a kernel:

T_exec = max(T_compute, T_memory)

where:

T_compute = FLOPs / (Peak FLOPS × f(occupancy))
T_memory = HBM bytes / HBM bandwidth

f(occupancy) is a monotonically increasing function reflecting occupancy’s impact on actual throughput. A common approximation is f(o) = max(0.25, o): throughput does not drop linearly below 25% occupancy because instruction-level parallelism still provides some utilization.
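The roofline estimate above is a few lines of Python. The peak numbers are illustrative A100-class values, and `kernel_time` is our own name, not a compiler API:

```python
# Roofline-style time estimate for one kernel: T_exec = max(T_compute,
# T_memory), with f(o) = max(0.25, o). Peaks are illustrative A100-class
# values (312 TFLOPS FP16, 2 TB/s HBM).

def kernel_time(flops, hbm_bytes, occupancy,
                peak_flops=312e12, hbm_bw=2.0e12):
    f = max(0.25, occupancy)                # occupancy's throughput factor
    t_compute = flops / (peak_flops * f)
    t_memory = hbm_bytes / hbm_bw
    return max(t_compute, t_memory)

# Two memory-bound pointwise kernels (4 MB in + 4 MB out each) vs. one
# fused kernel that reads/writes the same 8 MB once:
unfused = kernel_time(2e6, 8e6, 1.0) + kernel_time(1e6, 8e6, 1.0)
fused = kernel_time(3e6, 8e6, 1.0)
print(f"unfused: {unfused * 1e6:.2f} us, fused: {fused * 1e6:.2f} us")
```

Both kernels are memory-bound, so fusing them halves the total time by eliminating the intermediate tensor's HBM round-trip.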

The interactive component below lets you adjust hardware parameters and observe the effects of different fusion decisions.

[Interactive: Cost Model Calculator. Choose a GPU (V100 / A100 / H100) and a scenario: GELU + Dropout (fusion beneficial), Large Reduction + Small Pointwise (fusion harmful), or MatMul + BiasAdd + ReLU (tradeoff). Example, A100 (2 TB/s bandwidth, 312 TFLOPS, 164 KB shared memory) with GELU + Dropout: each unfused kernel is memory-bound at 4.00 μs (4 MB read + 4 MB write each), and the fused kernel still runs at 4.00 μs, so fusion halves total HBM traffic (16 MB → 8 MB) and total time (8.00 μs → 4.00 μs). For two memory-bound pointwise ops, fusion eliminates the intermediate tensor’s HBM read+write while FLOPs stay unchanged, so it is always beneficial.]

TorchInductor’s Cost Model in Practice

PyTorch 2’s TorchInductor (Ansel et al., 2024) uses a heuristic cost model — not exact modeling, but experience-based rules.

Core Fusion Heuristics

Pointwise fusion (almost always applied):

Inductor fuses consecutive pointwise operations (element-wise, broadcast) by default. The reasoning is simple: pointwise ops need no shared memory, register overhead is typically small, and eliminating intermediate tensor HBM round-trips is almost always a net positive.

Reduction + Pointwise fusion (conditional):

For a reduction (e.g., sum, max) followed by pointwise operations, Inductor checks:

  • Whether the reduction dimension is small enough (won’t overflow shared memory)
  • Whether post-fusion register usage stays manageable

If the reduction spans a large dimension (e.g., hidden_dim reduction over [batch, seq_len, hidden_dim]), fusion may cause excessive shared memory demand.
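As a rough illustration of how such rules compose, here is a toy decision function. This is NOT TorchInductor's actual code; names and thresholds are invented for illustration:

```python
# Toy fusion heuristic in the style described above -- NOT TorchInductor's
# implementation; thresholds and names are invented.

REG_THRESHOLD = 48          # hypothetical per-thread register budget
SMEM_CAPACITY = 164 * 1024  # bytes of shared memory per SM (A100-class)

def should_fuse(producer, consumer, est_regs_per_thread, est_smem_bytes):
    """Decide whether to fuse a producer/consumer pair that passed legality."""
    if producer == "pointwise" and consumer == "pointwise":
        # Pointwise chains: removing the intermediate tensor's HBM
        # round-trip is almost always a net win.
        return True
    if producer == "reduction" and consumer == "pointwise":
        # Conditional: the fused working set and register pressure
        # must stay within budget.
        return (est_smem_bytes <= SMEM_CAPACITY
                and est_regs_per_thread <= REG_THRESHOLD)
    return False  # be conservative for anything else

print(should_fuse("pointwise", "pointwise", 24, 0))      # True
print(should_fuse("reduction", "pointwise", 52, 16384))  # False: registers
```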

MatMul + Epilogue fusion (high value):

GEMM epilogue fusion is one of the most valuable fusion patterns. Before a GEMM tile writes back to HBM, bias add, activation (ReLU/GELU), and dropout are computed directly in registers. Both cuBLAS and CUTLASS natively support epilogue fusion; Triton achieves it through code generation.

Fusion Size Control

# Max nodes per fusion group
torch._inductor.config.max_fusion_size = 64  # default

# Max inputs when fusing pointwise ops into a concatenation
torch._inductor.config.max_pointwise_cat_inputs = 8  # default

Debugging Fusion Decisions

Understanding the compiler’s fusion decisions is crucial during development:

import torch

# Method 1: Set trace environment variable
# TORCHINDUCTOR_TRACE=1 python my_script.py

# Method 2: Enable in code
torch._inductor.config.trace.enabled = True
torch._inductor.config.trace.graph_diagram = True  # Generate pre/post fusion graphs

# Method 3: Inspect generated Triton kernels
torch._inductor.config.debug = True
# Generated kernel code in /tmp/torchinductor_<user>/

Examining trace output reveals fusion logs like:

[FUSION] fused pointwise: relu + mul + add → fused_kernel_0 (3 nodes)
[FUSION] skipped: layernorm + large_epilogue (register pressure: 52 > threshold 48)

MLIR-Level Fusion

MLIR (Multi-Level Intermediate Representation) provides a more principled approach to fusion than Inductor.

Linalg Dialect Fusion

MLIR’s Linalg dialect represents tensor operations as structured ops, naturally supporting producer-consumer fusion analysis. The transform op transform.structured.fuse_into_containing_op inlines a producer’s computation into the consumer’s loop body.

Tile-and-Fuse: The Core Strategy

MLIR’s fusion strategy follows a key principle: tile first, then fuse.

  1. Tile the consumer: Split the consumer’s computation into tiles that fit the target memory level (e.g., L1 cache or shared memory)
  2. Fuse producer into tile: Inline the producer’s computation into the consumer’s tile loop
  3. Guarantee working set fits: Since tile size is determined by memory capacity, the fused working set naturally won’t overflow

This contrasts sharply with Inductor’s “fuse first, hope it doesn’t overflow” approach. Tile-and-fuse is a correctness-first method — it designs fusion from memory constraints, rather than checking them after the fact.
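The three steps above can be sketched in scalar Python: derive the tile size from a memory budget, then inline the producer into the consumer's tile loop. The byte budget stands in for shared memory; all names and sizes are illustrative:

```python
# Scalar sketch of tile-and-fuse: the producer (a + b) is inlined into the
# consumer's (relu) tile loop, so only one tile's worth of the intermediate
# ever exists. Budget and names are illustrative.

TILE_BYTES = 1024   # pretend per-tile memory budget
ELEM_BYTES = 4      # fp32

def fused_add_relu(a, b):
    n = len(a)
    tile = TILE_BYTES // ELEM_BYTES       # 1. tile size from memory capacity
    out = [0.0] * n
    for start in range(0, n, tile):       #    consumer's tile loop
        for i in range(start, min(start + tile, n)):
            t = a[i] + b[i]               # 2. producer fused into the tile
            out[i] = max(t, 0.0)          # 3. working set never exceeds budget
    return out

a = [float(i) for i in range(1000)]
b = [float(-2 * i) for i in range(1000)]
assert fused_add_relu(a, b) == [max(x + y, 0.0) for x, y in zip(a, b)]
```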

Affine Fusion

For perfect affine loop nests, MLIR’s affine dialect supports loop fusion based on polyhedral analysis. This approach can automatically discover optimal loop fusion orders and tile sizes, but only applies to scenarios with static shapes and affine indices.

Comparison Summary

Dimension | TorchInductor | MLIR
Approach | Heuristic rules | Tile-and-fuse + polyhedral
Strengths | Fast, practical, covers common patterns | Principled, correctness guarantees
Weaknesses | May miss optimizations or make suboptimal decisions | Higher compile overhead, limited dynamic-shape support
Best for | JIT compilation (torch.compile) | AOT compilation (deployment optimization)

FlashAttention Deep Dive

FlashAttention (Dao et al., 2022) is one of the most impactful optimizations in ML systems. It is not generic fusion — it is a domain-specific algorithmic rewrite.

Standard Attention’s Memory Bottleneck

Standard Self-Attention computation flow:

Attention(Q, K, V) = softmax(QKᵀ / √d) · V

where Q, K, V ∈ ℝ^{N×d} (N = sequence length, d = head dimension).

The critical bottleneck is S = QKᵀ, an N×N matrix. For N = 4096, d = 64, FP16:

|S| = N² × 2 bytes = 4096² × 2 = 33,554,432 bytes ≈ 32 MB

Total HBM access for standard attention:

  1. Read Q: N × d × 2 = 0.5 MB
  2. Read K: N × d × 2 = 0.5 MB
  3. Write S: N² × 2 = 32 MB → HBM
  4. Read S: 32 MB ← HBM (for softmax)
  5. Write softmax(S): 32 MB → HBM
  6. Read softmax(S): 32 MB ← HBM (to multiply by V)
  7. Read V: 0.5 MB
  8. Write O: 0.5 MB

Total HBM access: ≈ 130 MB, with 128 MB spent on the N×N score matrix alone.

I/O complexity: Θ(Nd + N²). When N ≫ d (the usual case), the N² term dominates.
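The accounting above is easy to check in a few lines (function name ours):

```python
# HBM traffic of standard attention, following the per-step accounting
# above (FP16, so 2 bytes per element).

def std_attention_hbm_bytes(N, d, elem_bytes=2):
    qkv_io = 3 * N * d * elem_bytes      # read Q, K, V
    out_io = N * d * elem_bytes          # write O
    score_io = 4 * N * N * elem_bytes    # write S, read S, write P, read P
    return qkv_io + out_io + score_io

mb = std_attention_hbm_bytes(4096, 64) / 2**20
print(f"{mb:.1f} MB")   # 130.0 MB, dominated by the N x N score matrix
```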

FlashAttention’s Core Idea

FlashAttention’s key insight: the full N×N matrix never needs to be materialized.

Tile Q, K, V into blocks:

B_r = B_c = ⌊M / (4d × 2)⌋

where M is the SRAM (shared memory) size in bytes, d is the head dimension, and the factor of 2 accounts for FP16.

Algorithm outline:

For each Q tile (Br rows):
    Initialize output accumulator O_tile = 0, running max m = -inf, running sum l = 0
    For each K,V tile (Bc rows):
        Load Q_tile, K_tile, V_tile from HBM to SRAM
        Compute S_tile = Q_tile × K_tile^T in SRAM (size Br × Bc, fits entirely in SRAM!)
        Compute local softmax: m_new, l_new, P_tile
        Update O_tile with online softmax rescaling
    Write O_tile back to HBM

At each iteration, only a B_r × B_c sub-matrix exists in SRAM, never the full N×N.
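The outline above maps directly onto a NumPy sketch: no GPU, no parallelism, just the blocking and the online-softmax bookkeeping. Tile sizes are fixed small for clarity rather than derived from SRAM size, and the function name is ours:

```python
# NumPy sketch of the tiled FlashAttention forward pass. Only a Br x Bc
# score tile is ever formed; running max m and running sum l implement the
# online softmax rescaling.
import numpy as np

def flash_attention(Q, K, V, Br=16, Bc=16):
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))
    for qs in range(0, N, Br):                    # outer loop over Q tiles
        qe = min(qs + Br, N)
        Qt = Q[qs:qe]
        m = np.full(qe - qs, -np.inf)             # running row max
        l = np.zeros(qe - qs)                     # running row sum
        acc = np.zeros((qe - qs, d))              # unnormalized output
        for ks in range(0, N, Bc):                # inner loop over K/V tiles
            ke = min(ks + Bc, N)
            S = (Qt @ K[ks:ke].T) * scale         # Br x Bc tile, never N x N
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])
            corr = np.exp(m - m_new)              # rescale old accumulators
            l = l * corr + P.sum(axis=1)
            acc = acc * corr[:, None] + P @ V[ks:ke]
            m = m_new
        O[qs:qe] = acc / l[:, None]               # normalize and write back
    return O

# Check against a straightforward softmax(QK^T/sqrt(d))V reference:
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((64, 8)) for _ in range(3))
S = (Q @ K.T) / np.sqrt(8)
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(flash_attention(Q, K, V), ref)
```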

Online Softmax: The Key Trick

Softmax is fundamentally a global operation — it needs the maximum over an entire row. FlashAttention uses online softmax (Milakov & Gimelshein, 2018) to solve this:

For the softmax of a vector x = [x_1, x_2, …, x_N], where σ(x)_i = e^{x_i} / Σ_j e^{x_j}, the computation can be tiled:

  1. Process block 1: m_1 = max(x_{1:B}), l_1 = Σ_{i=1}^{B} e^{x_i − m_1}
  2. Process block 2: m_2 = max(m_1, max(x_{B+1:2B})), l_2 = l_1 · e^{m_1 − m_2} + Σ_{i=B+1}^{2B} e^{x_i − m_2}
  3. Continue, rescaling previous accumulations with each new max value

This computes the correct softmax without ever storing the full N-dimensional vector.
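The rescaling steps above, in pure Python (block size and names are ours):

```python
# Online softmax: process x in chunks, keeping only a running max m and a
# running rescaled sum l; every new max rescales the old sum by e^(m - m_new).
import math

def online_softmax(x, block=4):
    m, l = -math.inf, 0.0
    for start in range(0, len(x), block):
        chunk = x[start:start + block]
        m_new = max(m, max(chunk))
        # rescale the old sum to the new max, then add this block's terms
        l = l * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in chunk)
        m = m_new
    return [math.exp(v - m) / l for v in x]

x = [3.0, 1.0, -2.0, 0.5, 4.0, 2.5, -1.0, 0.0]
mx = max(x)
direct = [math.exp(v - mx) / sum(math.exp(u - mx) for u in x) for v in x]
assert all(abs(a - b) < 1e-12 for a, b in zip(online_softmax(x), direct))
```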

I/O Complexity Analysis

FlashAttention’s HBM access:

The outer loop runs ⌈N / B_r⌉ iterations, each of which:

  • Loads a Q tile: B_r × d × 2 bytes
  • Runs the inner loop for ⌈N / B_c⌉ iterations, each loading a K tile and a V tile: 2 × B_c × d × 2 bytes
  • Writes an O tile: B_r × d × 2 bytes

Total HBM access (using B_r = B_c):

HBM = O( (N/B_r) × (B_r·d + (N/B_c) × 2·B_c·d + B_r·d) ) = O(N²d/B_c + Nd)

Since B_c = Θ(M/d):

HBM = O(N²d²/M + Nd)

When M = Θ(Nd) (SRAM is large enough), this simplifies to O(Nd): linear in sequence length N!

Compared to standard attention’s Θ(Nd + N²), FlashAttention provides order-of-magnitude I/O efficiency gains on long sequences.

FlashAttention-2: Fewer Non-MatMul FLOPs

FlashAttention-2 (Dao, 2023) key improvements:

  1. Move rescaling out of the inner loop: Reduces non-matmul FLOPs. On GPUs, Tensor Cores only accelerate matmul; other operations use CUDA cores, which are much slower. FA-1 rescales at every step, producing many non-matmul FLOPs. FA-2 defers rescaling to the end of the outer loop.

  2. Better warp partitioning: FA-1 partitions warps across the d dimension (each warp processes part of the head), requiring cross-warp reduction. FA-2 partitions across the N dimension (each warp processes different K/V blocks), eliminating cross-warp communication.

Result: FA-2 achieves 50-73% of theoretical FLOPs utilization on A100, versus 25-40% for FA-1.

FlashAttention-3: Hopper Architecture Specialization

FlashAttention-3 (Shah et al., 2024) is deeply optimized for NVIDIA Hopper (H100):

  1. Warp-specialized pipeline: Exploits Hopper’s async execution — producer warps handle HBM→SRAM data movement while consumer warps compute, running as a pipeline.
  2. FP8 support: Hopper’s FP8 Tensor Cores deliver 2x FP16 throughput. FA-3 supports FP8 with incoherent processing to maintain numerical accuracy.
  3. Block quantization: Quantizes softmax intermediate results to FP8, reducing register and shared memory footprint.

Why FlashAttention is NOT “Operator Fusion”

It bears emphasizing: FlashAttention is not something a generic compiler can automatically discover. It requires:

  • Understanding the mathematical properties of softmax (can be computed online)
  • Designing a new algorithm (tiled attention with online softmax rescaling)
  • Hand-optimizing memory access patterns

No generic pattern matching or cost model can automatically derive this algorithm from the softmax(QKᵀ)V operator graph. This is a domain-specific algorithmic rewrite: the intersection of compiler technology and domain knowledge.

[Interactive: FlashAttention visualizer. Adjust sequence length N (512 to 8192) and SRAM size (48 / 164 / 228 KB) and step through the memory-access flow. Standard attention materializes the full N×N score matrix in HBM (8.0 MB for N = 2048), while FlashAttention only ever holds a small tile in SRAM. At N = 2048, d = 64, M = 164 KB: standard attention moves 33.0 MB over HBM versus 4.5 MB for FlashAttention, an 86% (7.4×) saving. Longer sequences widen FlashAttention’s advantage.]

Fusion Benchmarks in Practice

After the theoretical analysis, let us examine real-world data. The benchmark comparison below shows different fusion strategies across typical Transformer configurations (approximate values for educational purposes).

The four strategy levels:

  1. No Fusion: Each operator as a separate kernel, baseline
  2. Element-wise Only: Only pointwise fusion (GELU+dropout, bias+add, etc.)
  3. Full Inductor: TorchInductor’s complete fusion (including reduction fusion, epilogue fusion)
  4. Inductor + FlashAttn: Full Inductor plus FlashAttention
[Interactive: Fusion Strategy Benchmarks for GPT-2 Small (seq=1024, h=768, B=16), LLaMA 7B (seq=2048, h=4096, B=1), and LLaMA 70B (seq=4096, h=8192, B=1). Approximate values for educational purposes; figures for the displayed configuration:

Strategy | Throughput (TFLOPS) | Latency (ms) | Peak Memory | HBM Access (GB)
No Fusion | 45 | 42 | 4.8 GB | 6.2
Element-wise Only | 85 | 22 | 3.8 GB | 3.8
Full Inductor | 130 | 14.5 | 3.2 GB | 2.5
Inductor + FlashAttn | 175 | 10.8 | 2.4 GB | 1.8

Key insight: Full Inductor provides the largest single-step gain (85 → 130 TFLOPS); adding FlashAttention contributes a further 35% throughput and reduces peak memory by 25%.]

Analysis

Several key observations:

Universal benefit of element-wise fusion: Across all model sizes, element-wise fusion alone yields 1.8-2.0x throughput improvement. This is the “low-hanging fruit” — simple, safe, and with virtually no downside.

Value of Full Inductor: On top of element-wise fusion, Inductor’s reduction fusion and epilogue fusion contribute an additional 1.5-1.7x improvement. This is where the cost model becomes essential to avoid harmful fusions.

FlashAttention scales with sequence length:

  • GPT-2 Small (seq=1024): FlashAttention adds 31% throughput
  • LLaMA 7B (seq=2048): Adds 35% throughput
  • LLaMA 70B (seq=4096): Adds 41% throughput

Longer sequences mean a larger N² standard-attention overhead, making FlashAttention’s linear I/O advantage more pronounced.

Sustained peak memory reduction: Fusion improves not just speed but also memory usage. From No Fusion to Inductor+FlashAttn, peak memory drops 37-47%. This is critical for large model training — saved memory can be used for larger batch sizes or longer sequences.

Summary

Core lessons from this article:

  1. Cost models are essential. Fusion is not free — it can increase register pressure, reduce occupancy, and lengthen compile times. Mature compilers need quantitative analysis for fusion decisions.

  2. Heuristic vs. principled approaches. TorchInductor uses fast heuristics (suited for JIT), MLIR uses tile-and-fuse (suited for AOT). They are complementary, not competing.

  3. Algorithmic rewrites transcend generic fusion. FlashAttention demonstrates the power of domain knowledge — by understanding attention’s mathematical structure, it achieves optimizations no compiler could automatically discover. Future optimal systems will combine generic fusion (compiler) with domain-specific rewrites (libraries/algorithms).

  4. The performance optimization hierarchy: Element-wise fusion (basic) → Cost-model-guided fusion (intermediate) → Algorithmic rewrite (expert). Each level provides irreplaceable value.

Further Reading

  • FlashAttention paper trilogy: FA-1, FA-2, FA-3. Read sequentially to understand the evolution from core idea to Hopper specialization.
  • Roofline Model original paper: Williams, Waterman & Patterson (2009), “Roofline: An Insightful Visual Performance Model for Multicore Architectures”. The classic framework for understanding compute-bound vs memory-bound.
  • MLIR Linalg fusion docs: Official documentation. Implementation details for tile-and-fuse.
  • PyTorch 2 paper: Ansel et al. (2024), ACM DL. Engineering implementation of TorchInductor’s fusion strategies.
  • CUTLASS Epilogue Fusion: NVIDIA’s CUTLASS library provides templatized GEMM epilogue fusion — the best reference for understanding matmul+activation fusion.