Dynamic Shapes: The Full-Pipeline Challenge from Capture to Execution
Updated 2026-04-13
Introduction
Dynamic shapes represent the most critical practical challenge in ML compilation. In previous articles, we discussed operator fusion, cost models, and tiling strategies — all of these optimization techniques share an implicit assumption: the compiler knows the full tensor shape at compile time. However, in real-world LLM inference scenarios, this assumption almost never holds.
The core contradiction:
- Compilers need static information for optimization: constant folding requires known values, fusion needs intermediate result sizes, tiling needs matrix dimensions
- ML workloads are inherently dynamic: LLM inference processes requests with different sequence lengths, continuous batching varies batch size at runtime
This is a vertical thematic article that traces how dynamic shapes impact every stage of the compiler stack — from graph capture, IR representation, optimization passes, operator fusion, tiling, to code generation. We use PyTorch 2’s torch.compile as our primary example, complemented by MLIR’s dynamic tensor support, to comprehensively analyze the problems and solutions.
Problem Definition
Dynamic Shapes in LLM Inference
Consider a typical LLM inference scenario. A GPT model deployed in production receives requests from different users:
# User A sends a short text
input_A = tokenizer("Hello world")  # 3 tokens; after embedding: [1, 3, 768]
# User B sends a longer text
input_B = tokenizer("In this comprehensive guide, we will explore ...")  # 256 tokens: [1, 256, 768]
# User C sends a very long text
input_C = tokenizer("...")  # 2048 tokens: [1, 2048, 768]
The three requests have seq_len of 3, 256, and 2048 respectively. If the compiler produces a highly specialized kernel for shape [1, 3, 768], that kernel cannot handle the other two requests.
Under continuous batching, the situation is even more complex:
- Batch size changes dynamically at runtime (new request arrives → batch grows; request completes → batch shrinks)
- Even within the same batch, different sequences have different effective lengths (controlled by attention masks)
- KV cache length grows incrementally with each decoding step
This means a Transformer model has at least two dynamic dimensions: batch_size and seq_len.
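To make this concrete, here is how the shapes a compiler observes evolve over a single request during decoding (batch size 1 and hidden size 768 are illustrative):

# Shapes observed across one request (batch_size=1, hidden=768):
# prefill:  hidden_states [1, 256, 768], KV cache length 256
# decode 1: hidden_states [1, 1, 768],   KV cache length 257
# decode 2: hidden_states [1, 1, 768],   KV cache length 258
# ... every decoding step changes a shape somewhere in the graph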
Traditional Compilers vs ML Compilers
Traditional compilers (GCC, LLVM) process programs where types and data structure sizes are typically known at compile time (or can be inferred through static analysis). Even with dynamic memory allocation (malloc), the compiler doesn’t need to choose different optimization strategies based on allocation sizes.
ML compilers face a fundamentally different challenge: the shape of data directly influences optimization strategy selection. Consider a matrix multiplication C[M, N] = A[M, K] @ B[K, N]. If M is dynamic:
- Tiling: cannot select the optimal BLOCK_M at compile time, since the total number of tiles is unknown
- Fusion: cannot precisely determine whether intermediate results fit in SRAM
- Codegen: cannot eliminate bounds checks, because the last tile may need masking
- Memory planning: cannot pre-allocate exactly sized buffers at compile time
This is the essence of the dynamic shapes problem: shape is not just a data attribute — it is an input to optimization decisions.
Quantifying the Gap: Static vs Dynamic Performance
In a typical Transformer layer, the performance gap between static-shape compilation and dynamic-shape compilation ranges from approximately 10%–40%, depending on:
| Factor | Impact |
|---|---|
| Matrix dimensions | Small matrices are more affected (overhead is proportionally larger) |
| Number of dynamic dims | 1 dynamic dim < 2 < fully dynamic |
| Kernel type | Compute-bound kernels: lower impact; memory-bound: higher impact |
| Compiler maturity | Better symbolic reasoning narrows the gap |
PyTorch’s Solution — Symbolic Shapes
PyTorch 2, through TorchDynamo + AOTAutograd, introduces a comprehensive symbolic shape system to address the dynamic shapes challenge.
SymInt / SymFloat
SymInt is PyTorch’s core abstraction. It represents a symbolic integer — not a concrete value (like 128), but a symbolic variable (like s0), meaning “some value determined at runtime.”
import torch
def my_model(x):
# x.shape[1] is no longer a Python int, but a SymInt
batch_size = x.shape[0] # might be SymInt: s0
seq_len = x.shape[1] # might be SymInt: s1
hidden = x.shape[2] # if fixed, then int: 768
# SymInt supports arithmetic operations
output_size = seq_len * 2 # SymExpr: 2*s1
return x.reshape(batch_size, output_size, hidden // 2)
When TorchDynamo traces user Python code, it replaces concrete shape values with SymInts. All shape-related computations — parameters to reshape, view, permute, slice — become symbolic expressions. These expressions are recorded in the FX Graph and ultimately passed to the backend compiler.
Key design decisions of SymInt:
- Lazy evaluation: SymInt is not evaluated during tracing; concrete values are bound only at runtime
- Expression tracking: expressions like s0 + s1, s0 * 3, s0 // 8 are precisely tracked
- Constraint propagation: new constraints are derived from known ones (e.g., s0 > 0 implies 2 * s0 > 0)
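To see expression tracking end to end, here is a minimal sketch using torch.export (the Doubler module and dimension name are ours; the exact printed form varies by PyTorch version):

import torch
from torch.export import export, Dim

class Doubler(torch.nn.Module):
    def forward(self, x):
        # dim 1 becomes a SymInt, so 2*s1 is tracked symbolically;
        # x.shape[2] // 2 folds to 384 because dim 2 stays static
        return x.reshape(x.shape[0], x.shape[1] * 2, x.shape[2] // 2)

ep = export(
    Doubler(),
    (torch.randn(4, 128, 768),),
    dynamic_shapes={"x": {1: Dim("s1", min=1, max=2048)}},
)
print(ep)  # the printed graph carries symbolic sizes rather than the literal 128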
Guard System
Guards are PyTorch’s runtime validation mechanism. After torch.compile compiles a function, each subsequent call triggers a guard check: does the current input satisfy the assumptions made during compilation?
@torch.compile
def my_fn(x):
return x + 1
# First call: compile, shape=[4, 128, 768]
my_fn(torch.randn(4, 128, 768))
# Second call: guard check — does shape match?
# If same shape → cache hit → execute compiled code directly
my_fn(torch.randn(4, 128, 768))
# Third call: shape changed → guard fail → trigger recompilation
my_fn(torch.randn(4, 256, 768))
Guard types include:
- Shape Guard: x.shape[0] == 4 (most common, checks tensor shape)
- Dtype Guard: x.dtype == torch.float32 (checks data type)
- Device Guard: x.device == cuda:0 (checks device)
- Value Guard: checks concrete Python variable values (less common)
Guard checks themselves are very fast — just a few integer comparisons. The real cost is the recompilation triggered by guard failures.
dynamic=False vs dynamic=None (automatic_dynamic_shapes)
The dynamic parameter of torch.compile controls dynamic shape handling:
dynamic=False (explicit static mode):
- Each concrete shape combination is compiled and cached separately
- Shape change → guard fail → recompile a new specialized kernel
- Compilation cache indexed by concrete shape: {(4, 128, 768): kernel_1, (4, 256, 768): kernel_2, ...}
- Many shape variants lead to excessive recompilation
dynamic=None (default, automatic_dynamic_shapes):
- First call is compiled with static shapes
- If a dimension’s shape change causes a guard failure, that dimension is automatically marked as symbolic
- The recompiled kernel is generic for that dimension; subsequent calls with any seq_len hit the cache
- Dramatically reduces recompilation count (typically only 1 recompilation handles all shape variants)
This is a key PyTorch 2 design: automatically discovering which dimensions are dynamic, without requiring manual user annotation.
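A minimal sketch of automatic_dynamic_shapes in action (sizes are illustrative; exact recompilation points can vary across PyTorch versions):

import torch

@torch.compile  # dynamic=None by default
def f(x):
    return x * 2

f(torch.randn(4, 128))  # compile #1: fully static kernel for [4, 128]
f(torch.randn(4, 256))  # guard fails on dim 1 -> compile #2, dim 1 now symbolic
f(torch.randn(4, 512))  # cache hit: the symbolic kernel covers any dim-1 size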
Mark Dynamic API
Users can also proactively tell the compiler which dimensions are dynamic:
x = torch.randn(4, 128, 768)
# Mark dimension 1 (seq_len) as dynamic
torch._dynamo.mark_dynamic(x, 1)
# Or enable dynamic tracing via torch.compile's dynamic parameter
@torch.compile(dynamic=True)
def my_fn(x):
return x + 1
The effect of mark_dynamic: the compiler treats that dimension as a symbolic variable from the very first call, avoiding the initial guard failure and recompilation.
Symbol Constraint System
Symbolic variables are not unconstrained — PyTorch derives constraints from model structure and user hints:
# Example automatically derived constraints:
# s0 > 0 — dimensions must be positive
# s0 <= 2048 — derived from model's max_position_embeddings
# s0 % 8 == 0 — if model code has reshape(..., seq_len // 8, 8)
These constraints are passed to the backend compiler, narrowing the optimization search space. For instance, if the compiler knows s0 % 8 == 0, it can choose BLOCK_M = 8 or a multiple thereof without bounds checks.
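Constraints can also come from explicit user assertions. A hedged sketch using torch._check, which records a condition the symbolic layer may rely on (behavior details vary by version):

import torch

@torch.compile(dynamic=True)
def fold_heads(x):
    b, s, h = x.shape
    torch._check(s % 8 == 0)           # recorded as a divisibility constraint on s
    return x.reshape(b, s // 8, 8, h)  # valid under that constraint

fold_heads(torch.randn(2, 128, 768))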
Dynamic Shape Impact Across Compiler Stages
The impact of dynamic shapes is not isolated: it permeates every stage of the compiler stack. The subsections below walk through the stages one by one, comparing static and dynamic shapes.
Impact on Graph Capture
During graph capture, TorchDynamo converts Python code into an intermediate representation (FX Graph). With static shapes, all dimensions are Python ints and the graph structure is deterministic. With dynamic shapes, dimensions become SymInts and shape operations become symbolic expressions.
Key effects:
- Branch elimination: with static shapes, if x.shape[0] > 16: ... can be resolved during tracing, producing a single-path graph. With dynamic shapes, such shape-dependent control flow may require guards or dual-path graphs
- Graph breaks: certain operations are unsupported under dynamic shapes, causing graph breaks (splitting one large graph into multiple smaller ones) and increasing Python interpreter overhead
Impact on IR Representation
After graph capture, the computation graph is lowered to a backend IR (e.g., Inductor IR or MLIR). With static shapes, all dimensions are constants:
// Static IR
linalg.matmul
ins(%A: tensor<128x768xf32>, %B: tensor<768x512xf32>)
outs(%C: tensor<128x512xf32>)
With dynamic shapes, dimensions become symbolic variables or ? placeholders:
// Dynamic IR (MLIR style)
linalg.matmul
ins(%A: tensor<?x768xf32>, %B: tensor<768x512xf32>)
outs(%C: tensor<?x512xf32>)
MLIR’s RankedTensorType natively supports dynamic dimensions (using ?), allowing the IR to naturally express “known rank but unknown shape” tensors. Each ? dimension requires a runtime tensor.dim operation to obtain its value.
Loop bounds, stride calculations, buffer allocation sizes — all IR nodes that depend on dimension values transition from compile-time constants to runtime computations.
Impact on Optimization Passes
Optimization passes are most broadly affected by dynamic shapes. Here is a breakdown:
Passes that fail:
- Constant Folding: alloc(128 * 512 * 4) can be computed at compile time as alloc(262144), but alloc(s0 * 512 * 4) must remain a runtime computation
- Layout Optimization: the optimal memory layout may depend on specific dimension ratios (e.g., the M/N ratio of a matrix), which are unknown under dynamic shapes
- Static Memory Planning: the compiler cannot pre-allocate buffers of exact size; must use runtime allocation or worst-case pre-allocation
- Loop Unrolling: cannot unroll loops when bounds are unknown
Passes that still work:
- Dead Code Elimination (DCE): removes unused computations, independent of shape
- Common Subexpression Elimination (CSE): merges duplicate computations; still effective for shape-independent portions
- Algebraic Simplification: identities such as x * 1 = x and x + 0 = x are shape-independent
Impact on Fusion
The core decision in operator fusion is: can two adjacent operators be safely fused into a single kernel? This decision depends heavily on shape information.
SRAM capacity check: fusing two operators means intermediate results stay in SRAM rather than writing back to HBM. The compiler must verify that intermediate result sizes fit within SRAM capacity:
When s0 is dynamic:
- Static: 14 KB < 164 KB → safe to fuse
- Dynamic: s0 * 512 * 4 < 164 KB → cannot be decided at compile time → conservative strategy (don't fuse, or assume the worst case; a sketch of this check follows below)
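A hedged sketch of this legality check (the 164 KB budget mirrors the numbers above; can_fuse and its parameters are our own):

SRAM_BYTES = 164 * 1024  # shared-memory budget from the example above

def can_fuse(rows, cols=512, elem_bytes=4, max_rows=None):
    # rows is a concrete int for static shapes and None for dynamic ones;
    # max_rows is an optional range hint for the dynamic case
    if rows is not None:                       # static: exact check
        return rows * cols * elem_bytes < SRAM_BYTES
    if max_rows is not None:                   # dynamic with range hint: worst case
        return max_rows * cols * elem_bytes < SRAM_BYTES
    return False                               # fully dynamic: stay conservative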
Element-wise fusion is unaffected because it doesn’t need to buffer intermediate results:
# Regardless of shape, these operations can always be safely fused
y = relu(x + bias) # pointwise fusion always safe
Impact on Tiling and Code Generation
This is where dynamic shapes have the most direct, measurable impact.
Static Tiling:
# Optimal tile sizes determined at compile time
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
grid = (M // BLOCK_M, N // BLOCK_N, 1)  # computable at compile time (M, N known)
Dynamic Tiling:
# Tile size must be chosen conservatively; grid computed at runtime
BLOCK_M = 64  # smaller value to accommodate multiple shapes
grid = ((s0 + BLOCK_M - 1) // BLOCK_M, N // BLOCK_N, 1)  # ceil-div at runtime
Code generation impact is even more direct:
- Static kernel: the compiler knows exact dimensions, can eliminate all bounds checks, fully unroll loops, and select optimal vectorization width
- Dynamic kernel: must include masks (bounds checks), cannot unroll loops, and may suffer from warp divergence
import triton
import triton.language as tl

# Static kernel: no mask needed
@triton.jit
def kernel_128_512(A_ptr, B_ptr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    # The compiler knows all dimensions and that BLOCK divides them evenly,
    # so every access is in bounds
    a = tl.load(A_ptr + offs)  # no mask needed

# Dynamic kernel: mask required
@triton.jit
def kernel_generic(A_ptr, B_ptr, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N  # bounds check: the last block may be partial
    a = tl.load(A_ptr + offs, mask=mask, other=0.0)  # masked load
The cost of masks:
- Extra instructions: each load/store requires additional comparison and mask operations
- Warp divergence: threads within the same warp may have different mask values, causing branch divergence
- Vectorization constraints: masks may break contiguous memory access patterns
Engineering Strategies
In the face of the dynamic shapes challenge, engineering practice has developed a range of strategies to balance compilation quality and flexibility.
Bucketing
Bucketing is the most widely used strategy. The core idea: instead of compiling a kernel for every specific seq_len, map seq_len to discrete “buckets,” each corresponding to a compiled kernel.
For example, map all seq_len values to {64, 128, 256, 512, 1024, 2048}. A request with seq_len=200 is mapped to bucket 256, requiring 56 tokens of padding — this is the “waste.”
Different bucketing strategies make different tradeoffs between compilation count and padding waste:
- No bucketing: each unique seq_len compiled separately; zero waste but most compilations
- Power-of-2: map to nearest power of 2; very few compilations but high waste for short sequences
- Fixed interval (e.g., multiples of 128): balanced approach
- Multiple-of-8: minimal padding, Tensor Core alignment friendly
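A minimal sketch of the bucket lookup (the bucket list matches the example above; pick_bucket is our own name):

import bisect

BUCKETS = [64, 128, 256, 512, 1024, 2048]

def pick_bucket(seq_len: int) -> int:
    # Return the smallest bucket that fits seq_len
    i = bisect.bisect_left(BUCKETS, seq_len)
    if i == len(BUCKETS):
        raise ValueError(f"seq_len {seq_len} exceeds the largest bucket")
    return BUCKETS[i]

assert pick_bucket(200) == 256  # 56 tokens of padding, as in the example above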
Padding Strategies
Padding complements bucketing. After mapping a sequence length to a bucket, padding tokens fill the sequence to the bucket size. The padding strategy choice affects computational efficiency and memory usage:
Right Padding: the most common approach; add padding tokens at the end of the sequence. Combined with an attention mask, padding positions receive zero attention weight.
# Original sequence: [tok1, tok2, tok3] (len=3)
# Padded to 8: [tok1, tok2, tok3, PAD, PAD, PAD, PAD, PAD]
# Attention mask: [1, 1, 1, 0, 0, 0, 0, 0]
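A minimal right-padding sketch in PyTorch (right_pad and the token ids are ours):

import torch

def right_pad(ids: torch.Tensor, bucket: int, pad_id: int = 0):
    # Right-pad a 1-D token-id tensor to `bucket` and build its attention mask
    n_pad = bucket - ids.shape[0]
    padded = torch.cat([ids, torch.full((n_pad,), pad_id, dtype=ids.dtype)])
    mask = torch.cat([torch.ones_like(ids), torch.zeros(n_pad, dtype=ids.dtype)])
    return padded, mask

padded, mask = right_pad(torch.tensor([11, 12, 13]), bucket=8)
# padded: [11, 12, 13, 0, 0, 0, 0, 0]; mask: [1, 1, 1, 0, 0, 0, 0, 0]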
Alignment considerations: Tensor Cores (MMA instructions) typically require dimensions to be multiples of 8 or 16. Padding to multiples of 8 satisfies both Tensor Core alignment requirements and ensures memory coalescing.
Shape Hints
Shape hints allow users to provide range information about dynamic dimensions, helping the compiler find a middle ground between “fully static” and “fully dynamic.”
# torch.export with dynamic shapes
from torch.export import export, Dim
batch = Dim("batch", min=1, max=32)
seq_len = Dim("seq_len", min=1, max=2048)
exported = export(
model,
(sample_input,),
dynamic_shapes={"x": {0: batch, 1: seq_len}},
)
With these hints, the compiler can:
- Range-based optimization: knowing seq_len <= 2048 enables fixed-size buffer pre-allocation
- Partition optimization: divide the range into intervals, each using a different tile configuration
- Eliminate some bounds checks: if seq_len % 8 == 0, masks can be omitted
AOT vs JIT Compilation
Facing dynamic shapes, AOT (Ahead-of-Time) and JIT (Just-in-Time) compilation have different strengths:
AOT compilation:
- Pre-compile kernels for all expected shape variants
- Advantage: zero compilation latency during inference, predictable latency
- Disadvantage: requires knowing the shape distribution in advance; long compile time; may miss rare shapes
# AOT: pre-compile kernels for all bucket sizes
# (compile_kernel stands in for whatever compile entry point the stack provides)
for bucket_size in [64, 128, 256, 512, 1024, 2048]:
    compile_kernel(model, seq_len=bucket_size)
JIT compilation:
- Compile dynamically on first encounter of a new shape; cache for subsequent use
- Advantage: flexible, automatically adapts to any shape, no need to predict distribution
- Disadvantage: compilation latency on first encounter of each new shape (cold start)
# JIT: default behavior of torch.compile
@torch.compile
def model_fn(x):
return model(x)
# First call compiles; subsequent calls execute directly if shape matches
Hybrid strategy (most common in production):
- AOT pre-compile for common shapes (e.g., bucket sizes)
- JIT as fallback for rare shapes
- Background async compilation to mitigate cold-start impact
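A hedged sketch of such a hybrid dispatcher (aot_kernels, jit_compile, and the bucket list are hypothetical placeholders):

AOT_BUCKETS = [64, 128, 256, 512, 1024, 2048]
aot_kernels = {}   # bucket -> kernel, filled by AOT pre-compilation at startup
jit_cache = {}     # seq_len -> kernel, filled lazily at serving time

def get_kernel(seq_len):
    # Fast path: smallest pre-compiled bucket that fits the request
    for b in AOT_BUCKETS:
        if seq_len <= b and b in aot_kernels:
            return aot_kernels[b]
    # Slow path: JIT-compile once for this exact shape, then cache
    if seq_len not in jit_cache:
        jit_cache[seq_len] = jit_compile(seq_len)  # hypothetical JIT entry point
    return jit_cache[seq_len]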
Practical Analysis
Scenario: LLM Inference Service
Consider an LLM inference service with the following request distribution:
| seq_len Range | Request Share | Typical Use Case |
|---|---|---|
| 1–32 | 20% | Short Q&A, completions |
| 33–128 | 35% | Normal conversation |
| 129–512 | 30% | Long conversations, summarization |
| 513–2048 | 15% | Long document analysis |
Strategy Comparison
Strategy A: torch.compile(dynamic=False)
Every unique seq_len triggers a compilation. With 100 different seq_len values, that means 100 compilations. Under high traffic, compilation overhead can account for over 30% of total latency.
Strategy B: torch.compile(dynamic=None) + default behavior
After the first guard failure, the seq_len dimension is automatically marked symbolic. All subsequent seq_len values hit the cache. Only 2 compilations total (initial + 1 recompilation). Performance is approximately 10–15% lower than Strategy A (because the kernel is generic, not specialized).
Strategy C: Bucketing + AOT pre-compilation
Map seq_len to {32, 64, 128, 256, 512, 1024, 2048} — seven buckets. AOT pre-compile 7 specialized kernels. Each kernel is highly optimized for its specific shape.
| Metric | Strategy A | Strategy B | Strategy C |
|---|---|---|---|
| Compilations | ~100 | 2 | 7 (pre-compiled) |
| Inference latency | Lowest (specialized) | Medium (generic) | Low (specialized) |
| First-request latency | High (compiling) | High (compiling) | Low (pre-compiled) |
| Padding waste | 0% | 0% | ~15% average |
| Memory usage | High (many cached kernels) | Low (1–2 kernels) | Medium (7 kernels) |
Common Pitfalls
- Forgetting mark_dynamic: the model has shape-dependent control flow (if x.shape[0] > 16) but dynamic dimensions aren't marked, causing a guard failure and recompilation on every shape change
- Too many guards: each model layer has independent guard checks; some intermediate tensor shapes indirectly depend on input shapes, leading to excessively long guard chains
- Shape-dependent control flow:
# Dangerous: shape-dependent control flow
def forward(self, x):
if x.shape[1] > 512: # takes different branch based on seq_len
return self.long_path(x)
return self.short_path(x)
This causes the compiler to generate different graphs for each branch condition, creating frequent graph breaks when shapes change. The recommended approach is a unified path with mask-based control.
- Dynamic shape + data-dependent shape:
# Extreme case: output shape depends on input values, not just input shape
indices = (x > 0).nonzero() # output shape depends on the number of positive values in x
# This is extremely difficult for compilers to handle
Summary
Dynamic shapes represent a systematic challenge for ML compilers — affecting every stage of the compiler stack:
- Graph capture: SymInt replaces concrete values; the guard system controls recompilation
- IR representation: dynamic dimensions are represented as symbolic variables or ?; loop bounds become runtime values
- Optimization passes: key passes such as constant folding, layout optimization, and static memory planning are partially or fully disabled
- Operator fusion: SRAM capacity checks become indeterminate; fusion strategy forced to be conservative
- Tiling: tile sizes cannot be optimized at compile time; grid dimensions require runtime computation
- Code generation: generic kernels require bounds checks (masks), potentially causing warp divergence
PyTorch 2’s SymInt/Guard system provides a practical solution:
- dynamic=None (default) uses automatic_dynamic_shapes to discover dynamic dimensions
- The guard system enables runtime checks with minimal overhead
- Symbol constraints propagate dynamic dimension knowledge to backend compilers
Engineering strategies (bucketing, padding, shape hints, AOT/JIT hybrid) bridge the compiler’s gaps in practice, finding balance between compilation quality and flexibility through reasonable tradeoffs.
Looking ahead: dynamic shape handling continues to evolve rapidly. Better symbolic reasoning (propagating more information through the constraint system), automatic bucketing (the compiler automatically selecting bucket sizes based on runtime shape distributions), and shape-polymorphic kernels (a single kernel adapting to multiple shapes through a few runtime parameters) are the primary directions of development.
The next article dives into code generation — translating optimized IR into efficient hardware instructions, the final output of the compiler stack.
Further Reading
- PyTorch 2 paper (ASPLOS 2024): comprehensive treatment of TorchDynamo, AOTAutograd, and the dynamic compilation system
- torch.compile dynamic shapes documentation: official usage guide (torch.compiler_dynamic_shapes)
- TorchDynamo deep dive: internals of the guard system implementation
- MLIR documentation — RankedTensorType: design of the dynamic dimension type system in MLIR