Graph Optimization Passes (Part 2): Advanced Optimizations & Pattern Matching
Updated 2026-04-13
Introduction
In the previous article, we established the foundational framework for graph optimization passes: Pass Manager architecture, basic optimizations (DCE, CSE, constant folding), and dependency management between passes. These form the skeleton of modern compiler optimization systems.
But what truly makes a compiler “intelligent” are the advanced optimizations that understand and reshape program semantics. This article dives deep into three core topics:
- Layout Optimization — Choosing and propagating optimal memory layouts
- Pattern Matching — Declarative graph rewriting systems
- Memory Planning — Lifetime analysis and buffer reuse
These optimization techniques are key to ML compiler performance. A good layout decision can yield 2-3x speedup; precise pattern matching can fuse multiple operators into a single efficient kernel; intelligent memory planning can reduce peak memory usage by over 50%.
Layout Optimization: The Art of Data Layout
Why Layout Matters So Much
In deep learning, the same logical tensor can have multiple memory layout representations. For a 4D convolution feature map (Batch N, Channel C, Height H, Width W), the two most common layouts are:
- NCHW (Channel-first): Memory stores all of channel 0 first, then channel 1, and so on
- NHWC (Spatial-first): Memory stores all channels at position (0,0) first, then (0,1), and so on
Mathematically, these two layouts are completely equivalent — both represent the same 4D tensor. But performance differences can reach 2-5x. The reasons:
- Hardware access patterns: GPU Tensor Cores map convolutions to implicit GEMM; NHWC layout stores the channel dimension (K dimension) contiguously, enabling efficient Tensor Core fragment loading
- Cache utilization: NCHW stores spatial positions of the same channel contiguously, benefiting channel-wise operations; NHWC stores all channels at the same position contiguously, benefiting cross-channel operations
- Fusion opportunities: Different layouts affect which operators can be fused. Layout conversion has overhead; too many conversions can negate fusion benefits
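To make the difference concrete, here is a small, self-contained sketch (plain Python with illustrative sizes, not any framework's code) that computes the row-major strides of the two layouts and shows which dimension ends up contiguous in memory:

```python
def strides_for(order, N, C, H, W):
    """Row-major strides (in elements) for a given dimension order, e.g. 'NCHW'."""
    sizes = {'N': N, 'C': C, 'H': H, 'W': W}
    strides, acc = {}, 1
    for dim in reversed(order):
        strides[dim] = acc
        acc *= sizes[dim]
    return strides

N, C, H, W = 1, 64, 56, 56
print(strides_for('NCHW', N, C, H, W))  # {'W': 1, 'H': 56, 'C': 3136, 'N': 200704}
print(strides_for('NHWC', N, C, H, W))  # {'C': 1, 'W': 64, 'H': 3584, 'N': 200704}

# Offset of element (n, c, h, w) = n*s['N'] + c*s['C'] + h*s['H'] + w*s['W'].
# Under NCHW, stepping along W is contiguous; stepping along C jumps H*W elements.
# Under NHWC, stepping along C is contiguous, which matches Tensor Core GEMM loads.
```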
Layout Propagation Algorithm
The compiler needs to choose a layout for each tensor. This isn’t a local problem — one tensor’s layout affects the layout choices of its producers and consumers. Therefore, layout optimization is fundamentally a global optimization problem.
The classic Layout Propagation algorithm has three phases:
Phase 1: Layout Constraint Inference
Traverse the compute graph and annotate each operation’s preference for input/output layouts:
def infer_layout_constraints(op):
if op.type == 'conv2d':
# cuDNN convolution has highly optimized NCHW implementation
return {'input': [NCHW], 'weight': [NCHW], 'output': [NCHW]}
elif op.type == 'matmul' and has_tensor_core():
# Tensor Core requires contiguous K dimension
return {'input': [NHWC], 'weight': [ANY], 'output': [NHWC]}
elif op.type == 'batch_norm':
# BatchNorm operates on channel dimension, NCHW is friendlier
return {'input': [NCHW], 'output': [NCHW]}
# ...
Phase 2: Cost-based Layout Selection
Model layout selection as a global cost-minimization problem. Define the objective function:

$$\min_{L} \; \sum_{op} C_{op}(L_{op}) \;+\; \sum_{(u,v) \in E} C_{\mathrm{conv}}(L_u, L_v)$$

- $C_{op}(L_{op})$: Execution cost of operation $op$ under a specific layout (can be obtained via profiling)
- $C_{\mathrm{conv}}(L_u, L_v)$: Cost of converting a tensor from the producer's layout to the consumer's layout (typically a memory-bound operation)
This is an NP-hard problem (similar to graph coloring). In practice, heuristic algorithms are used:
def select_layouts(graph):
# Initialize: choose optimal layout for each op (local optimum)
layouts = {op: op.preferred_layout() for op in graph.ops}
# Iterative optimization: greedily reduce transpose operations
changed = True
while changed:
changed = False
for op in graph.ops:
# Try changing op's layout to see if total cost decreases
for candidate_layout in op.compatible_layouts():
if cost_delta(op, candidate_layout) < 0:
layouts[op] = candidate_layout
changed = True
return layouts
Phase 3: Transpose Insertion
Insert explicit transpose operations on edges with inconsistent layouts:
for edge in graph.edges:
src_layout = layouts[edge.src]
dst_layout = layouts[edge.dst]
if src_layout != dst_layout:
insert_transpose(edge, src_layout, dst_layout)
Real-world Layout Optimization Case
Take ResNet-50 as optimized by TensorRT as an example:
- Convolution layers use NCHW (cuDNN optimization)
- Fully-Connected layers use NHWC (after flattening to 2D, NHWC is equivalent to row-major)
- Insert a single transpose at Conv-FC boundary
This mixed layout strategy typically provides significant improvement over pure NCHW or pure NHWC (exact gains depend on model structure and hardware configuration).
Shape Inference and Specialization
Static vs Dynamic Shape
ML compilers need to know each tensor’s shape to perform aggressive optimizations (like loop unrolling, tiling parameter selection). But shape information availability varies greatly:
- Static shape: Fully known at compile time, like `tensor<128x768xf32>`
- Dynamic shape: Only the rank is known; dimension sizes are determined at runtime, like `tensor<?x?xf32>`
Static shapes unlock numerous optimizations:
// Static shape: compiler can directly unroll loops, precompute indices
linalg.matmul ins(%A: tensor<128x768xf32>, %B: tensor<768x512xf32>)
outs(%C: tensor<128x512xf32>)
// Dynamic shape: requires runtime branches, indirect indexing
linalg.matmul ins(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>)
outs(%C: tensor<?x?xf32>)
Shape Specialization with Guards
TorchInductor adopts a compromise: Shape Specialization with Guards.
- Assume specific shape values at compile time (e.g., batch=32, seq_len=512)
- Generate optimized static shape code
- Insert runtime guards: Check at code entry if shapes match assumptions
def optimized_forward(x, w):
# Guard: if shape doesn't match, fallback to generic path
if x.shape != (32, 512, 768):
return fallback_forward(x, w)
# Here is code optimized for (32, 512, 768)
# Loop bounds are constants, can be unrolled and vectorized
return specialized_matmul_32_512_768(x, w)
Guard overhead is typically small (a few integer comparisons), while the optimization gains are substantial. In PyTorch 2.0, this strategy brings dynamic-shape model performance close to that of static shapes.
Symbolic Shape Inference
For more complex dynamic shapes (like [B, S, H*D] where B and S are dynamic), the compiler needs symbolic shape inference:
# Input shapes
x: [B, S, D]
w: [D, 3*D]
# matmul output shape
qkv = x @ w # [B, S, 3*D]
# split qkv into Q, K, V along the last dimension
q, k, v = split(qkv, dim=-1, chunks=3) # each [B, S, D]
The compiler maintains a symbol table recording equality constraints between dimensions (like D_out = 3 * D_in) and updates constraints at each operation. This allows the compiler to derive relationships between dimensions even when shapes are dynamic, enabling optimizations (like determining if two matmuls can be fused).
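As a rough illustration (hypothetical classes, not any framework's actual implementation), a symbolic shape system can be as simple as named dimensions plus a set of equality constraints:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class SymDim:
    name: str        # symbolic dimension, e.g. "B", "S", "D"
    factor: int = 1  # 3*D is represented as SymDim("D", factor=3)

def matmul_shape(a_shape, b_shape, constraints):
    """a: [..., M, K1] @ b: [K2, N] -> record K1 == K2, return [..., M, N]."""
    constraints.add((a_shape[-1], b_shape[0]))   # equality constraint on the K dims
    return a_shape[:-1] + [b_shape[-1]]

constraints = set()
x_shape = [SymDim("B"), SymDim("S"), SymDim("D")]
w_shape = [SymDim("D"), SymDim("D", factor=3)]

qkv_shape = matmul_shape(x_shape, w_shape, constraints)
# qkv_shape == [B, S, 3*D], and the constraint set records that x's last dim
# equals w's first dim, so a later split into 3 chunks of size D is provably exact.
print(qkv_shape, constraints)
```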
Memory Planning: Lifetime and Reuse
Tensor Lifetime Analysis
The first step in memory planning is determining each tensor’s liveness: the interval from definition (produce) to last use (consume).
def compute_liveness(graph):
    liveness = {}
    for node in graph.topological_order():
        for output in node.outputs:
            liveness[output] = {
                'birth': node.index,
                # A value with no users dies right after its producing node
                'death': max((user.index for user in output.users),
                             default=node.index)
            }
    return liveness
If two tensors’ lifetimes don’t overlap, they can share the same memory buffer.
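Building on `compute_liveness`, a minimal greedy sketch of that reuse idea (buffer sizes ignored for brevity; a real allocator would also match sizes) might look like this:

```python
def assign_buffers(liveness):
    """Assign each tensor a buffer id; non-overlapping lifetimes may share an id."""
    buffer_last_use = []   # buffer_last_use[i] = death index of buffer i's current occupant
    assignment = {}
    # Visit tensors in order of birth so earlier tensors claim buffers first
    for tensor, lv in sorted(liveness.items(), key=lambda kv: kv[1]['birth']):
        for buf_id, last_use in enumerate(buffer_last_use):
            if last_use < lv['birth']:           # previous occupant is dead: reuse
                buffer_last_use[buf_id] = lv['death']
                assignment[tensor] = buf_id
                break
        else:                                     # nothing reusable: allocate a new buffer
            buffer_last_use.append(lv['death'])
            assignment[tensor] = len(buffer_last_use) - 1
    return assignment
```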
In-place Mutation Detection
Some operations can modify inputs in-place, avoiding extra allocation:
# Non-in-place
y = x + 1 # Allocate new tensor y
# In-place (if x is not used afterward)
x.add_(1) # Directly modify x
The compiler needs to detect whether in-place modification is safe:
- Uniqueness check: Does the input tensor have other aliases? If `y = x; z = x + 1`, the add cannot be done in-place
- Lifetime check: Is the input still used after this operation?
MLIR’s bufferization pass automatically handles these checks, lowering from tensor semantics to memref and inserting necessary copies.
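A minimal sketch of the two checks above, assuming an FX-like node API (`node.inputs`, `value.users`, and `node.index` are illustrative names, not a specific framework's interface):

```python
def can_rewrite_inplace(node, liveness):
    """Return True if node may safely overwrite its first input in place."""
    inp = node.inputs[0]
    # Uniqueness check: if the input has users other than this node,
    # mutating it in place would be visible through those aliases.
    if any(user is not node for user in inp.users):
        return False
    # Lifetime check: the input must not be live past this operation.
    if liveness[inp]['death'] > node.index:
        return False
    return True
```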
Pool Allocation Strategy
For tensors that cannot be in-place, the compiler uses memory pools to reduce fragmentation and allocation overhead:
class MemoryPool:
def __init__(self):
self.free_buffers = [] # Free buffers sorted by size
self.allocated = {} # tensor -> buffer mapping
def allocate(self, tensor, size):
        # Best Fit: free_buffers is kept sorted by size, so the first buffer
        # that is large enough is also the smallest one that fits
for buf in self.free_buffers:
if buf.size >= size:
self.free_buffers.remove(buf)
self.allocated[tensor] = buf
return buf
# No suitable free buffer, allocate new one
buf = new_buffer(size)
self.allocated[tensor] = buf
return buf
def free(self, tensor):
buf = self.allocated.pop(tensor)
self.free_buffers.append(buf)
self.free_buffers.sort(key=lambda b: b.size)
More advanced strategies use graph coloring: treat tensors as nodes and connect an edge between any two tensors whose lifetimes overlap; buffer assignment then becomes a graph-coloring problem, and tensors assigned the same color share a buffer (a toy version of this formulation follows below).
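A minimal sketch of that coloring view (greedy first-fit coloring, sizes ignored; illustrative only, not any framework's allocator):

```python
def color_tensors(liveness):
    """Tensors are nodes; overlapping lifetimes are edges; equal colors share a buffer."""
    def overlaps(a, b):
        return not (liveness[a]['death'] < liveness[b]['birth'] or
                    liveness[b]['death'] < liveness[a]['birth'])

    colors = {}
    for t in liveness:
        taken = {colors[u] for u in colors if overlaps(t, u)}
        color = 0
        while color in taken:   # pick the smallest color not used by any neighbor
            color += 1
        colors[t] = color
    return colors
```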
In practice, PyTorch Inductor uses a mixed strategy:
- Small tensors (< 1MB): Use pool allocation
- Large tensors (> 1MB): Direct allocation and deallocation (avoid pool fragmentation)
- Long-lifetime tensors: Don’t enter pool (like model weights)
Deep Dive into Pattern Matching
The Challenge of Graph Rewriting
The core of graph optimization is identifying patterns and replacing them. For example, identifying MatMul + BiasAdd and replacing with FusedLinear. The naive approach is hand-written matching code:
def fuse_linear(graph):
    # Iterate over a snapshot because we mutate the graph during traversal
    for node in list(graph.nodes):
        if node.op == 'add' and node.inputs[0].op == 'matmul':
            matmul = node.inputs[0]
            bias = node.inputs[1]
            if is_param(bias):
                # Match succeeded: build the fused op and rewire users of the add
                fused = create_fused_linear(matmul.inputs, bias)
                replace_node(node, fused)
Problems with this approach:
- Verbose code: Each pattern requires writing matching logic
- Hard to maintain: Adding new patterns requires understanding the entire matching framework
- Inefficient: Each graph traversal checks all patterns
MLIR DRR: Declarative Rewrite Rules
MLIR’s Declarative Rewrite Rules (DRR) provides a declarative pattern definition language:
// File: FusionPatterns.td
def FuseLinearPattern : Pat<
// Pattern to match: add(matmul(x, w), bias)
(AddOp
(MatMulOp $x, $w),
$bias
),
// Replacement: FusedLinear(x, w, bias)
(FusedLinearOp $x, $w, $bias),
// Constraint: bias must be 1D
[(IsOneDim $bias)]
>;
DRR automatically generates C++ matching code from this declaration, including:
- Topological traversal logic
- Constraint checking
- Node replacement and edge updates
DRR’s core advantage is composability: multiple small patterns can compose into complex patterns without rewriting matching logic.
PDL: Pattern Description Language
MLIR also provides the more general PDL (Pattern Description Language), which can express complex patterns that DRR cannot:
pdl.pattern @FuseConvBNPattern : benefit(2) {
%input = pdl.operand
%conv_weight = pdl.operand
%conv = pdl.operation "linalg.conv_2d"(%input, %conv_weight : !pdl.value, !pdl.value)
%conv_result = pdl.result 0 of %conv
%bn_scale = pdl.operand
%bn_bias = pdl.operand
%bn = pdl.operation "linalg.batch_norm"(%conv_result, %bn_scale, %bn_bias : !pdl.value, !pdl.value, !pdl.value)
%bn_result = pdl.result 0 of %bn
pdl.rewrite %bn {
%fused = pdl.operation "linalg.fused_conv_bn"(%input, %conv_weight, %bn_scale, %bn_bias)
%fused_result = pdl.result 0 of %fused
pdl.replace %bn with %fused_result
}
}
PDL is more general and more expressive: patterns can call native constraints and rewrite functions implemented in C++, so essentially arbitrary matching conditions can be expressed. The cost is a more verbose syntax and longer compilation times.
torch.fx Subgraph Rewriting
PyTorch’s torch.fx provides Python-native pattern matching:
import torch
from torch.fx import symbolic_trace, replace_pattern
def pattern(x, w, bias):
mm = torch.matmul(x, w)
return torch.add(mm, bias)
def replacement(x, w, bias):
return torch.nn.functional.linear(x, w.T, bias)
# Apply replacement to model
model = symbolic_trace(model)
replace_pattern(model, pattern, replacement)
FX's replace_pattern internally uses a subgraph-isomorphism matcher to find occurrences of the pattern. While slower than DRR, it is friendly to Python users: patterns and replacements are ordinary Python functions, and extra match constraints can be expressed in plain Python.
Complete Pass Pipeline for Transformer Attention
Now let’s integrate these optimization techniques and examine a real example: the optimization pipeline for Transformer self-attention.
Pipeline Stage Analysis
Stage 1: Original Graph
This is the initial representation after FX capture. Three independent Q/K/V projections (matmul), one softmax, one output projection. 14 nodes, no optimizations.
Stage 2: Constant Folding
The compiler recognizes that the attention scaling factor $1/\sqrt{d_k}$ is a compile-time constant (here $d_k = 64$, so $1/\sqrt{64} = 0.125$) and folds it directly to 0.125. It also converts the div into a mul (mul is about 2x faster than div on GPU).
Stage 3: QKV Projection Fusion
This is the most critical step. Three independent matmuls:

$$Q = xW_Q, \quad K = xW_K, \quad V = xW_V$$

can be fused into one large matmul:

$$QKV = x \left[\, W_Q \mid W_K \mid W_V \,\right]$$

(the three weight matrices are concatenated along the output dimension), and the result is then split into three parts. Why is this faster?
- Reduced kernel launch overhead: 3 kernel launches become 1
- Reduced memory reads: the input $x$ is read from HBM only once (instead of three times)
- Improved Tensor Core utilization: Large matrix multiplications saturate hardware more easily
This step reduces HBM access from 100% to 75%.
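A small PyTorch sketch of the rewrite itself (illustrative sizes; a compiler performs this on the graph rather than in eager code):

```python
import torch

B, S, D = 2, 128, 768
x = torch.randn(B, S, D)
w_q, w_k, w_v = (torch.randn(D, D) for _ in range(3))

# Before: three separate matmuls, x is read three times
q, k, v = x @ w_q, x @ w_k, x @ w_v

# After: one large matmul against the concatenated weights, then a cheap split
w_qkv = torch.cat([w_q, w_k, w_v], dim=1)    # [D, 3*D]
q2, k2, v2 = (x @ w_qkv).chunk(3, dim=-1)    # each [B, S, D]

assert torch.allclose(q, q2, atol=1e-3) and torch.allclose(v, v2, atol=1e-3)
```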
Stage 4: Layout Optimization
The compiler annotates layouts for each tensor:
- `x`: `[B, S, H*D]` (batch, sequence, hidden)
- `Q/K/V`: `[B, H, S, D]` (batch, heads, sequence, head_dim) — converted to multi-head form
- `scores`: `[B, H, S, S]` (attention scores)
These layout annotations guide subsequent lowering: reshape operations are lowered to views (zero cost) rather than copies.
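For instance, in PyTorch terms (an eager-mode illustration, not the compiler's IR), the multi-head reshape/transpose can stay a zero-copy view until some consumer actually requires a contiguous buffer:

```python
import torch

B, S, H, Dh = 2, 128, 12, 64
x = torch.randn(B, S, H * Dh)

# Lowered as a view: only shape/stride metadata changes, no data movement
q = x.view(B, S, H, Dh).transpose(1, 2)       # [B, H, S, Dh]
print(q.data_ptr() == x.data_ptr())           # True: same underlying storage

# Forcing the layout a kernel actually requires materializes a copy
q_contig = q.contiguous()
print(q_contig.data_ptr() == x.data_ptr())    # False: a new buffer was allocated
```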
Stage 5: Memory Planning
The compiler analyzes lifetimes:
- `QKV` can be freed after the split
- `Q/K/V` share `QKV`'s buffer (via views)
- `scores` and `softmax` update in-place (`scores` is not used after the softmax)
- `attn` reuses `QKV`'s buffer (their lifetimes don't overlap)
- `out` reuses the input `x`'s buffer (if `x` is not used afterward)
Final HBM access drops to 60%, peak memory usage reduced by about 40%.
Real Performance Comparison
On A100 GPU, for typical GPT-3 attention (B=1, H=96, S=2048, D=128), below are educational estimates based on the Roofline model (actual latencies vary with driver versions, kernel implementations, etc.):
| Stage | Latency (ms) | Relative Speedup |
|---|---|---|
| Original (PyTorch eager) | 2.8 | 1.0x |
| + Constant folding | 2.75 | 1.02x |
| + QKV fusion | 1.9 | 1.47x |
| + Layout optimization | 1.65 | 1.70x |
| + Memory planning | 1.55 | 1.81x |
| FlashAttention-2 (SOTA) | 0.95 | 2.95x |
As shown, even these “basic” optimizations yield 1.8x speedup. FlashAttention-2, through more aggressive tiling and on-chip memory utilization, achieves 3x speedup.
Summary
This article explored three major advanced graph optimization techniques:
Layout Optimization chooses optimal data layouts to match hardware access patterns with operator characteristics. The core challenge is global optimization — one tensor’s layout affects its neighbors, requiring cost-based algorithms to balance operation costs and conversion costs.
Pattern Matching through declarative graph rewriting systems enables compilers to identify and replace subgraph patterns. MLIR’s DRR and PDL provide powerful pattern description capabilities, while PyTorch FX’s replace_pattern offers a Python-native interface.
Memory Planning through lifetime analysis and buffer reuse significantly reduces memory footprint and allocation overhead. Core techniques include liveness analysis, in-place mutation detection, and pool allocation strategies.
These three work together to form the core optimization engine of modern ML compilers. In the next article, we’ll explore Code Generation & Scheduling — how to translate optimized high-level graphs into efficient GPU kernels.
Further Reading
- MLIR Declarative Rewrite Rules documentation — Detailed DRR syntax and semantics
- MLIR PDL documentation — Complete reference for Pattern Description Language
- PyTorch `torch.fx.replace_pattern` source code — Understanding the Python-native pattern matching implementation
- "Tensor Comprehensions" paper — Facebook's early tensor compiler, exploring joint optimization of layout and schedule
- NVIDIA Tensor Core Programming Guide — Understanding Tensor Core layout requirements and how to write efficient WMMA code