Quantization Compilation and Mixed-Precision Optimization
Updated 2026-04-13
Introduction
In the quantization learning path, we have already discussed quantization fundamentals (numerical formats and quantization math), post-training weight quantization (GPTQ, AWQ, and other methods), and inference-time quantization (dynamic quantization and KV cache compression). Those articles answer the core question of “how to quantize” — how to compress FP32/FP16 weights and activations to INT8, INT4, or even lower precision.
This article shifts perspective and answers an equally critical question from the compiler’s viewpoint: How does the compiler handle the computation graph after quantization?
The importance of this question is often underestimated. A carefully quantized model may yield performance gains far below theoretical values if the compiler cannot properly optimize its execution graph. Specifically, the compiler must solve three core challenges:
- Fusion: Quantization introduces numerous Dequant→Compute→Quant “sandwich” structures. If each dequant/quant executes as an independent operator, type conversion overhead can negate quantization benefits. The compiler must identify and fuse these patterns.
- Mixed-precision routing: A single Transformer layer may simultaneously contain INT4 weights, FP16 activations, FP32 LayerNorm, and FP32 Softmax. The compiler needs to propagate precision information through the graph and insert a minimal set of type conversion nodes while ensuring numerical stability.
- Kernel generation: Different quantization formats require fundamentally different compute kernels. Weight-only INT4 needs on-the-fly dequant + FP16 MatMul; W8A8 needs integer GEMM + requantization; FP8 on Hopper architecture has dedicated Tensor Core instructions. The compiler must select or generate optimal kernels for each combination.
Building on the knowledge from code generation and the Triton backend, this article dives deep into the full compilation optimization pipeline for quantized graphs. We cover Quantization-Aware Fusion, Mixed-Precision Compilation strategies, and Quantized Kernel Generation, with interactive visualizations to help build intuitive understanding.
Compilation Challenges for Quantized Graphs
The Dequant-Compute-Quant Sandwich Structure
When a model undergoes PTQ (Post-Training Quantization), its computation graph changes fundamentally. Consider a simple Linear → ReLU → Linear example:
The computation graph before quantization is straightforward:
X(FP16) → MatMul(W1) → ReLU → MatMul(W2) → Y(FP16)
After quantization (assuming INT4 weights plus INT8 activation quantization between layers), the graph balloons in size:
X(FP16) → Dequant(W1_INT4→FP16) → MatMul → Quant(FP16→INT8)
→ Dequant(INT8→FP16) → ReLU → Quant(FP16→INT8)
→ Dequant(INT8→FP16) → Dequant(W2_INT4→FP16) → MatMul → Y(FP16)
This structure is called the Dequant-Compute-Quant sandwich pattern: each compute operation is surrounded by quantize/dequantize nodes on both sides. If the compiler naively executes these operators in topological order, severe performance issues arise:
- Memory bandwidth waste: Each dequant/quant requires reading and writing intermediate tensors, generating massive DRAM round-trips. On an A100, a single DRAM read/write has 10-100x the latency of a Tensor Core computation.
- Kernel launch overhead: Each independent operator corresponds to a CUDA kernel launch, with the GPU sitting idle between launches.
- Type conversion compute overhead: INT4→FP16 dequant requires unpacking + scale multiplication + zero-point addition — all element-wise operations with extremely low compute density.
Consider a 4096x4096 INT4 Linear layer. The weight size is 4096x4096x0.5 = 8MB (INT4 packed). If dequant executes as an independent kernel:
- Read 8MB INT4 data
- Write out 32MB FP16 data
- Total memory access: 40MB
If dequant is fused into the MatMul kernel:
- Read 8MB INT4 data (convert on-the-fly in registers)
- Feed directly to Tensor Core
- Eliminate the 32MB intermediate write
This single step saves 80% of memory bandwidth.
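The arithmetic behind these numbers is easy to verify; a few lines of plain Python (standalone, no framework assumed):

# Memory traffic for dequantizing a 4096x4096 INT4 weight to FP16
K = N = 4096
int4_bytes = K * N // 2              # 2 values per byte -> 8 MB
fp16_bytes = K * N * 2               # 2 bytes per value -> 32 MB

unfused = int4_bytes + fp16_bytes    # read INT4 + write FP16 intermediate = 40 MB
fused = int4_bytes                   # read INT4 only, dequant in registers =  8 MB
print(f"savings: {1 - fused / unfused:.0%}")   # -> savings: 80%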
Graph-Level Complexity of Mixed Precision
A real Transformer inference graph is not single-precision — it is a “precision mosaic.” A typical LLaMA-2 7B inference configuration might look like:
| Component | Weight Precision | Activation Precision | Accumulation Precision |
|---|---|---|---|
| Embedding | FP16 | FP16 | — |
| QKV Projection | INT4 (GPTQ) | FP16 | FP32 |
| Attention Score | — | FP16 | FP32 |
| Softmax | — | FP32 | FP32 |
| KV Cache | — | FP8/INT8 | — |
| Output Projection | INT4 (GPTQ) | FP16 | FP32 |
| LayerNorm | FP32 | FP32 | FP64* |
| FFN Gate/Up | INT4 (GPTQ) | FP16 | FP32 |
| FFN Down | INT4 (GPTQ) | FP16 | FP32 |
| SiLU/GeLU | — | FP16 | — |
(*LayerNorm variance computation uses FP64 accumulation in some implementations, but most inference frameworks use FP32 accumulation + Welford’s online algorithm to ensure numerical stability.)
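As an aside on that footnote, Welford’s algorithm keeps a running mean and a running sum of squared deviations, so the variance never comes from a large cancellation-prone subtraction. A minimal NumPy sketch, illustrative rather than any framework’s actual kernel:

import numpy as np

def welford_mean_var(x_fp16: np.ndarray):
    """Online mean/variance accumulated in FP32, as many LayerNorm kernels do."""
    mean = np.float32(0.0)
    m2 = np.float32(0.0)   # sum of squared deviations from the running mean
    for i, v in enumerate(x_fp16.astype(np.float32), start=1):
        delta = v - mean
        mean += delta / np.float32(i)
        m2 += delta * (v - mean)
    return mean, m2 / np.float32(x_fp16.size)   # population variance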
Facing this precision map, the compiler must answer a series of questions:
- Precision Propagation: Starting from the known precision of leaf nodes (weights and inputs), how to derive the data type of each edge? When two tensors of different precision meet (e.g., INT4 weight x FP16 activation), which precision takes precedence?
- Type Conversion Minimization: Subject to numerical correctness, how to insert the fewest possible cast nodes? This is a constrained optimization problem.
- Safe Precision Regions: Which operations can safely execute at low precision (e.g., INT8 ReLU is completely lossless), and which must maintain high precision (e.g., Softmax exponentials, LayerNorm variance computation)?
- Hardware Constraints: Which precision combinations do Tensor Cores support? Is INT4xFP16 natively supported, or does it require software dequant?
The answers to these questions determine the actual inference performance of the quantized model. The latency gap between a well-optimized compiler and a naive one can be 2-4x on the same quantized model.
Quantifying Type Conversion Overhead
To understand the necessity of compiler optimization, let us quantify the actual overhead of type conversions. On NVIDIA A100:
- INT4 → FP16 dequant: Requires bit-shift + mask (unpacking 2 INT4s from 1 byte), multiply by scale, add zero_point. Approximately 5-8 FLOPs per element, with extremely low arithmetic intensity (~0.1 FLOP/byte), completely memory-bound.
- FP16 → INT8 quant: Requires divide by scale + round + clamp. Approximately 4 FLOPs per element, also memory-bound.
- FP32 → FP16 cast: Natively supported in hardware, nearly zero overhead (completed inside Tensor Core).
- FP16 → FP32 cast: Same as above, zero overhead.
Key observation: INT4/INT8-related quantize/dequantize operations are compute-trivial but memory-heavy. Their computation is negligible, but the memory traffic they generate is not. This makes them ideal targets for compiler fusion — absorbing them into compute-intensive MatMul kernels, using MatMul’s high arithmetic intensity to “mask” the type conversion overhead.
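To make the per-element work concrete, here is an unfused reference of INT4 dequantization in NumPy; the packing layout (low nibble first) and the per-tensor scale are simplifying assumptions for illustration:

import numpy as np

def dequant_int4(packed: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    """Unpack 2 INT4 values per byte, then compute (w - zero_point) * scale."""
    lo = packed & 0x0F                     # lower 4 bits
    hi = (packed >> 4) & 0x0F              # upper 4 bits
    w = np.stack([lo, hi], axis=-1).reshape(*packed.shape[:-1], -1)  # interleave
    return (w.astype(np.float16) - np.float16(zero_point)) * np.float16(scale)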
Quantization-Aware Fusion
Quantization-Aware Fusion is the core optimization in quantization compilation. Its goal: identify dequant-compute-quant patterns, absorb type conversions into compute kernels, and eliminate intermediate tensor memory round-trips.
Pattern 1: Weight Dequant Fusion
This is the most basic and most important fusion pattern. In weight-only quantization (GPTQ, AWQ, etc.), weights are stored in INT4/INT8 and need dequantization to FP16 before participating in MatMul.
Before fusion:
W_int4 → Dequant → W_fp16 ─┐
├─ MatMul → Y_fp16
X_fp16 ─────────────────────┘
After fusion:
W_int4 ─┐
├─ FusedDequantMatMul → Y_fp16
X_fp16 ─┘
The fused kernel implementation strategy:
- Load phase: Load INT4 packed data from global memory to shared memory. Each thread is responsible for unpacking its corresponding weight elements: shift, mask, scale, offset — all completed in registers.
- Compute phase: Dequantized FP16 values are fed directly into Tensor Core mma instructions (wmma::mma_sync).
- Store phase: MatMul results are written out directly, with no intermediate tensors.
This fusion is a standard implementation in TensorRT, vLLM (via Marlin kernel), and llama.cpp. Taking a 4096x4096 MatMul as an example:
- Pre-fusion memory access: Read 8MB INT4 + Write 32MB FP16 + Read 32MB FP16 + Read 32MB X + Write 32MB Y = 136MB
- Post-fusion memory access: Read 8MB INT4 + Read 32MB X + Write 32MB Y = 72MB
- Savings: 47% memory bandwidth
Pattern 2: Epilogue Quant Fusion
When MatMul output needs immediate quantization (e.g., in W8A8 schemes where the next layer expects INT8 input), the compiler can fuse the quantize operation as a MatMul epilogue:
Before fusion:
X → MatMul → BiasAdd → ReLU → Quant → Y_int8
After fusion:
X → FusedMatMulBiasReLUQuant → Y_int8
This fusion pattern is called Post-Processing Fusion or Epilogue Fusion in TensorRT. After Tensor Core completes the matrix multiplication, results remain in registers, and the subsequent bias add, activation function, and quantize (scale + round + clamp) all complete at the register level. Only the final INT8 result is written to global memory.
Key advantages:
- Eliminates intermediate FP16/FP32 tensor writes and reads: Saves 3 memory round-trips
- ReLU/GeLU quantization friendliness: Post-ReLU values are non-negative, allowing unsigned INT8 usage, with range shifting from [-128, 127] to [0, 255] — effectively doubling precision
- Rounding mode optimization: Fused kernels can select optimal rounding (e.g., stochastic rounding), while standalone kernels typically use standard round-to-nearest-even
TensorRT’s IQuantizeLayer + IConvolutionLayer automatic fusion exemplifies this pattern. The compiler uses pattern matching to identify MatMul → [optional BiasAdd] → [optional Activation] → Quantize sequences and replaces them with a single fused kernel.
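Numerically, the fused epilogue is just a handful of element-wise steps applied to the accumulator before it leaves registers. A NumPy reference of the math, with a per-tensor output scale assumed and illustrative names:

import numpy as np

def matmul_bias_relu_quant(x, w, bias, out_scale):
    acc = x.astype(np.float32) @ w.astype(np.float32)    # GEMM with FP32 accumulate
    act = np.maximum(acc + bias, 0.0)                    # bias add + ReLU
    q = np.rint(act / out_scale)                         # scale + round
    return np.clip(q, 0, 255).astype(np.uint8)           # unsigned INT8 (post-ReLU values are non-negative)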
Pattern 3: Full Dequant-Compute-Quant Fusion
The most aggressive fusion pattern merges input dequant, computation, and output quant entirely:
Before fusion:
X_int8 → Dequant → X_fp16 ─┐
├─ MatMul → Y_fp16 → LayerNorm → Quant → Y_int8
W_int4 → Dequant → W_fp16 ─┘
After fusion:
X_int8 ─┐
├─ FusedDequantMatMulLNQuant → Y_int8
W_int4 ─┘
This fusion eliminates all intermediate FP16 tensors — the entire subgraph’s inputs and outputs are low-precision integers. This is the ultimate optimization for W8A8 (SmoothQuant-style) schemes.
However, full fusion has important constraints:
- LayerNorm precision requirements: LayerNorm’s variance computation needs FP32 accumulation precision. The fused kernel must maintain an FP32 reduction buffer internally.
- Scale parameter handling: Dequant and quant scale/zero_point parameters must be passed as kernel arguments, increasing kernel parameter complexity.
- Tile size constraints: LayerNorm requires full-row reduction, which constrains tiling strategy — the row dimension must be processed completely.
- Correctness verification: Multi-step fusion increases the risk of numerical error accumulation. The compiler must prove that the fused numerical results fall within acceptable error bounds.
TensorRT uses its myelin compiler backend to search for such large-scale fusion opportunities. It evaluates fusion benefits through a cost model (memory savings vs. kernel complexity), selecting the fusion scheme with the greatest benefit.
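The numerical contract such a fully fused kernel must satisfy can be written as an unfused reference; a real kernel reproduces this result without materializing any FP16 intermediate. Per-tensor symmetric scales are an assumption of this sketch:

import numpy as np

def sandwich_reference(x_q, w_q, s_x, s_w, gamma, beta, s_out, eps=1e-5):
    x = x_q.astype(np.float32) * s_x             # dequant INT8 activation
    w = w_q.astype(np.float32) * s_w             # dequant INT4/INT8 weight
    y = x @ w                                    # MatMul (FP32 accumulate)
    mu = y.mean(axis=-1, keepdims=True)          # LayerNorm statistics in FP32
    var = y.var(axis=-1, keepdims=True)
    y = (y - mu) / np.sqrt(var + eps) * gamma + beta
    return np.clip(np.rint(y / s_out), -128, 127).astype(np.int8)   # requantize to INT8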
Fusion Decision Cost Model
Compilers do not blindly fuse every possible pattern. The benefits and costs of fusion must be weighed:
Benefits:
- Reduced kernel launch count
- Eliminated intermediate tensor memory round-trips
- Increased data reuse in registers/shared memory
Costs:
- Increased register pressure in the single kernel
- Reduced GPU occupancy
- Increased compilation time
TensorRT’s cost model considers:
- Arithmetic intensity change: Does the fused kernel remain compute-bound? If fusion changes the kernel from compute-bound to memory-bound (due to register spills from insufficient registers), fusion may be a negative optimization.
- Occupancy impact: Fused kernels use more registers/shared memory, potentially reducing warps per SM and degrading latency hiding.
- Data reuse opportunities: If an intermediate tensor has multiple consumers (fan-out > 1), fusion leads to redundant computation.
A practical heuristic: When the intermediate tensor size exceeds L2 cache capacity, fusion is almost always worthwhile. This is because not fusing means DRAM round-trips, and DRAM bandwidth is the GPU’s scarcest resource. On A100, L2 cache is 40MB. A 4096x4096 FP16 tensor is 32MB, near the tipping point; an 8192x8192 tensor at 128MB far exceeds L2, making fusion enormously beneficial.
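This heuristic is simple enough to state as code; the 40MB threshold is A100-specific and the rule is a simplification of a full cost model:

def should_fuse(intermediate_bytes: int, l2_bytes: int = 40 * 2**20) -> bool:
    """Fuse when the intermediate tensor would not fit in L2,
    i.e. when not fusing forces a DRAM round-trip."""
    return intermediate_bytes > l2_bytes

print(should_fuse(4096 * 4096 * 2))   # 32MB intermediate -> False (borderline, fits in L2)
print(should_fuse(8192 * 8192 * 2))   # 128MB intermediate -> True (fusion pays off)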
Mixed-Precision Compilation Strategies
Mixed-Precision Compilation addresses a higher-level problem: across the entire computation graph, how to select optimal precision for each operation, and how to efficiently insert type conversions at precision boundaries.
Precision Propagation Algorithm
Compilers use a Precision Propagation algorithm to determine the data type of each edge in the graph. This process resembles type inference but incorporates precision constraints:
Step 1: Annotate leaf node precision
Starting from the known precision of weights and inputs:
- GPTQ-quantized weights are annotated as INT4
- Activations are annotated as FP16 (or INT8, depending on the quantization scheme)
- LayerNorm weight/bias typically remain FP32
Step 2: Forward propagation
For each operator, determine output precision based on semantics:
- MatMul(INT4, FP16): Output FP16 (INT4 must dequant to FP16 first)
- MatMul(INT8, INT8): Output INT32 (natural output of integer GEMM)
- Softmax(any): Output at least FP16 (numerical requirement)
- LayerNorm(any): Output at least FP16 (variance computation requirement)
- ReLU(any): Preserve input precision (no precision loss operation)
Step 3: Backward propagation of constraints
Some operations impose strong constraints on input precision:
- Softmax input must be >= FP16 (otherwise overflows)
- LayerNorm internal accumulation must be >= FP32
- Final output (logits) typically requires FP16
Step 4: Conflict resolution
When forward and backward propagation conflict:
- Higher precision takes precedence (conservative strategy)
- Insert cast nodes to resolve precision mismatches
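A toy version of this four-step propagation might look as follows. The precision lattice, rule table, and graph representation are simplified illustrations, not any particular compiler’s implementation:

# Toy precision propagation: forward pass assigns output dtypes,
# backward constraints force minimum input precision, casts fill the gaps.
RANK = {"int4": 0, "int8": 1, "fp16": 2, "fp32": 3}     # simplified precision lattice

FORWARD_RULES = {
    "matmul":    lambda ins: "fp16",    # low-bit weights dequant into FP16 math
    "softmax":   lambda ins: "fp16",
    "layernorm": lambda ins: "fp16",
    "relu":      lambda ins: ins[0],    # precision-preserving op
}
MIN_INPUT = {"softmax": "fp16", "layernorm": "fp16"}     # backward constraints

def propagate(graph, leaf_dtypes):
    """graph: {name: (op, [input names])} in topological order."""
    dtypes, casts = dict(leaf_dtypes), []
    for name, (op, inputs) in graph.items():
        in_dt = [dtypes[i] for i in inputs]
        need = MIN_INPUT.get(op)
        for i, dt in zip(inputs, in_dt):
            if need and RANK[dt] < RANK[need]:
                casts.append((i, dt, need))              # record inserted cast node
        dtypes[name] = FORWARD_RULES[op](in_dt)
    return dtypes, casts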
Actual compiler implementations (such as TensorRT) use Calibration data to assist precision decisions. By running the model on a small calibration dataset, measuring the activation distribution per layer (range, outlier ratio), the compiler decides which layers can safely lower precision and which must maintain high precision.
AMP and Compiler Cooperation
PyTorch’s AMP (Automatic Mixed Precision) and compiler mixed-precision optimization share similar goals but operate at different levels:
| Feature | PyTorch AMP | Compiler (TensorRT/TorchInductor) |
|---|---|---|
| Granularity | Operator-level | Subgraph-level |
| Precision selection | Whitelist/blacklist | Cost model + Calibration |
| Cast insertion | At every operator boundary | Minimum set after fusion |
| Runtime overhead | Present (autocast checks) | Zero (determined at compile time) |
The compiler’s advantage lies in its access to global information. For example, in a MatMul → ReLU → MatMul chain where both MatMuls are FP16, the intermediate ReLU needs no casts — while AMP’s per-operator strategy might insert a cast at ReLU’s input and output.
When processing quantized models, torch.compile performs precision propagation during the graph lowering phase. The TorchInductor backend recognizes quantization-related ops like torch.ops.quantized_decomposed.quantize_per_tensor, treating them as precision annotations rather than independent compute nodes. In subsequent fusion and codegen phases, these annotations guide kernel precision selection.
The Precision-Performance-Quality Triangle
The core challenge in mixed-precision strategy is balancing three objectives:
- Quality: Numerical accuracy of model outputs, typically measured by perplexity (language models) or accuracy (classification models).
- Throughput: Tokens processed per second or inferences per second.
- Memory: GPU memory footprint, determining the maximum deployable model size and batch size.
Fundamental trade-offs exist among these three:
- Lower precision → Reduced memory, higher throughput, but quality may degrade
- Maintain high precision → Quality guaranteed, but memory and throughput are constrained
- Mixed precision → Maintain high precision on critical paths, lower precision on tolerant paths — this is the optimal solution in engineering practice
The compiler’s role in this triangle: maximize performance under user-specified precision constraints. Users specify target precision for each layer through the quantization scheme (e.g., W4A16), and the compiler ensures correct execution of the precision scheme while minimizing execution overhead.
Specifically, compiler precision decisions follow these priorities:
- Safety first: Numerically sensitive operations like Softmax and LayerNorm never drop below FP16
- Hardware alignment: Prefer Tensor Core supported precision combinations (FP16xFP16, INT8xINT8, FP8xFP8)
- Minimum casts: After satisfying the above constraints, choose the scheme with fewest cast node insertions
- Bandwidth optimal: When cast node counts are equal, choose the scheme with smallest intermediate tensor memory footprint
Practical Example: W4A16 Precision Routing
Using W4A16 (INT4 weights, FP16 activations) as an example, let us analyze the compiler’s precision routing decisions:
Input(FP16) → Embedding(FP16) → [TransformerBlock x N] → LMHead(FP16) → Logits(FP16)
TransformerBlock:
RMSNorm(FP16, internal FP32)
→ QKV_Proj: Dequant(W_INT4→FP16) → MatMul(FP16xFP16→FP32→FP16)
→ Attention: Softmax(FP16, internal FP32) → AttnMatMul(FP16)
→ O_Proj: Dequant(W_INT4→FP16) → MatMul(FP16xFP16→FP32→FP16)
→ Residual Add(FP16)
→ RMSNorm(FP16, internal FP32)
→ Gate_Up: Dequant(W_INT4→FP16) → MatMul(FP16xFP16→FP32→FP16)
→ SiLU x element_mul
→ Down: Dequant(W_INT4→FP16) → MatMul(FP16xFP16→FP32→FP16)
→ Residual Add(FP16)
In this graph, the compiler decides:
- All Dequant(W_INT4→FP16) ops are fused into the subsequent MatMul → 4 dequants eliminated
- MatMul internal accumulation uses FP32 (handled automatically by Tensor Core)
- Softmax internally uses FP32 + online softmax algorithm
- RMSNorm internally uses FP32 accumulation
- No explicit cast nodes needed (activations remain FP16 throughout)
The resulting execution graph is much more concise than the original quantized graph. Each TransformerBlock is compressed from ~20 operators to ~10.
Quantized Kernel Generation
Different quantization schemes require fundamentally different compute kernels. The compiler must generate efficient specialized kernels for each quantization format, or select optimal implementations from a kernel template library.
Weight-Only INT4 Kernel
Weight-only INT4 (W4A16) is currently the most popular quantization scheme for LLM inference. Its kernel structure:
# Pseudocode: INT4 Weight-Only MatMul Kernel
import triton
import triton.language as tl

@triton.jit
def w4a16_matmul(
    x_ptr,          # FP16 activation [M, K]
    w_packed_ptr,   # INT4 packed weight [K, N/2] (2 INT4 values per byte)
    scale_ptr,      # FP16 scale [K/group_size, N]
    zeros_ptr,      # INT4 zero_point [K/group_size, N/2]
    output_ptr,     # FP16 output [M, N]
    ...,            # strides, group_size, BLOCK_M/N/K, etc.
):
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):              # iterate over the reduction dimension
        # 1. Load a tile of packed INT4 weights (2 values per byte)
        w_packed = tl.load(w_packed_ptr + w_offsets)
        # 2. Unpack: shift + mask
        w_lo = w_packed & 0x0F                  # lower 4 bits
        w_hi = (w_packed >> 4) & 0x0F           # upper 4 bits
        # 3. Dequantize: (w_int4 - zero_point) * scale
        #    scales/zeros are reloaded only once per group (e.g., every 128 rows)
        w_int4 = tl.interleave(w_lo, w_hi)      # restore column order (layout-dependent)
        w_fp16 = ((w_int4 - zeros) * scales).to(tl.float16)
        # 4. Tile MatMul with Tensor Cores: FP16 x FP16 → FP32 accumulate
        acc += tl.dot(x_tile, w_fp16)
    # 5. Store result tile
    tl.store(output_ptr + out_offsets, acc.to(tl.float16))
Key performance considerations:
- Unpack overhead: INT4 unpacking requires shift+mask, with 1 byte of packed data yielding two elements. In Triton, these bit operations complete entirely in registers with minimal overhead.
- Group Quantization: Scale and zero_point are typically shared per group (e.g., 128 elements per group). This means every 128 dequant operations only load one set of scale/zero_point, amortizing scale read costs.
- Memory bandwidth savings: INT4 weights are 1/4 the size of FP16. For memory-bound inference scenarios (batch size = 1), INT4’s theoretical speedup is ~3.5-4x over FP16 (accounting for scale and unpack overhead).
NVIDIA Marlin Kernel: The high-performance W4A16 kernel used in vLLM achieves near-theoretical INT4 bandwidth utilization through:
- 128-bit LDG.128 instructions loading 32 INT4 values at once
- Carefully designed thread-to-data mapping eliminating bank conflicts
- Asynchronous loads (cp.async) hiding global memory latency
- Weight permutation matching Tensor Core data layouts
W8A8 Integer Kernel
W8A8 (SmoothQuant-style) uses integer GEMM with native Tensor Core support at the hardware level:
INT8 activation x INT8 weight → INT32 accumulate → requantize → INT8 output
This is the only scheme that truly leverages INT8 Tensor Core instructions (imma). Its kernel structure:
- Load: Directly load INT8 data from global memory (no unpacking needed)
- Compute: Tensor Core executes INT8 x INT8 → INT32 matrix multiplication
- Requantize: INT32 accumulated result x output_scale → round → clamp → INT8
- Store: INT8 result written out
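The requantize step reduces to folding the input, weight, and output scales into one multiplier applied to the INT32 accumulator. A NumPy reference with per-tensor scales assumed (real kernels do this while the accumulator is still in registers):

import numpy as np

def w8a8_matmul_requant(x_q, w_q, s_x, s_w, s_out):
    acc = x_q.astype(np.int32) @ w_q.astype(np.int32)      # INT8 x INT8 -> INT32 accumulate
    y = acc.astype(np.float32) * (s_x * s_w / s_out)        # fold the three scales into one
    return np.clip(np.rint(y), -128, 127).astype(np.int8)   # round + clamp -> INT8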
Performance advantages:
- Tensor Core utilization: INT8 TC peak throughput is 2x FP16 TC (A100: 624 vs 312 TOPS)
- Memory bandwidth: INT8 data is half the size of FP16
- No dequant overhead: Both input and output are INT8 with no precision conversion
However, W8A8 has significant accuracy challenges:
- Activation INT8 quantization is much harder than weight quantization (activation distributions are more dynamic, with more outliers)
- Requires SmoothQuant or similar techniques to balance weight and activation quantization difficulty
- Per-token dynamic quantization requires additional scale computation kernels
FP8 Kernel (Hopper Architecture)
NVIDIA H100 (Hopper architecture) introduced FP8 Tensor Core support, offering two FP8 formats:
- E4M3: 4-bit exponent + 3-bit mantissa, range [-448, 448], precision ~0.0625
- E5M2: 5-bit exponent + 2-bit mantissa, range [-57344, 57344], precision ~0.25
FP8’s core advantage is the best precision-performance balance point:
| Property | FP16 | INT8 | FP8 (E4M3) |
|---|---|---|---|
| H100 TC Throughput (dense) | ~990 TFLOPS | ~1,979 TOPS | ~1,979 TFLOPS |
| Memory Bandwidth Savings | 1x | 2x | 2x |
| Dynamic Range | +/-65504 | [-128, 127] | [-448, 448] |
| Quantization Difficulty | — | High (needs calibration) | Medium (FP format naturally handles outliers) |
FP8 E4M3 on H100 achieves the same Tensor Core throughput as INT8, but because it is a floating-point format, it naturally supports dynamic range scaling, making quantization difficulty far lower than INT8.
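A per-tensor E4M3 cast can be sketched in a few lines of PyTorch (recent releases expose the float8_e4m3fn dtype; the max-based scaling policy shown here is one simple choice, not any framework’s default):

import torch

E4M3_MAX = 448.0   # largest finite E4M3 value

def to_fp8_e4m3(t: torch.Tensor):
    """Scale into the E4M3 range, cast, and keep the scale for later dequant."""
    scale = t.abs().max().clamp(min=1e-12) / E4M3_MAX
    return (t / scale).to(torch.float8_e4m3fn), scale

def from_fp8_e4m3(t_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return t_fp8.to(torch.float16) * scale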
Compiler support for FP8 kernels:
# TensorRT-LLM FP8 GEMM call
# Compiler auto-generates the following cuBLAS call
cublasLtMatmul(
handle,
matmulDesc, # CUBLAS_COMPUTE_32F
alpha_ptr, # FP32 scale
A_ptr, # FP8 E4M3
Adesc, # CUDA_R_8F_E4M3
B_ptr, # FP8 E4M3
Bdesc, # CUDA_R_8F_E4M3
beta_ptr,
C_ptr, # FP16 or BF16
Cdesc,
D_ptr, # output
Ddesc,
...
)
Key compiler decisions in FP8 kernel generation:
- E4M3 vs E5M2: Forward inference uses E4M3 (precision-first), backward pass (if any) uses E5M2 (range-first)
- Scale strategy: Per-tensor scale (fastest) vs per-channel scale (more accurate)
- Output precision: FP8 GEMM output can be FP16, BF16, or FP32, depending on the precision requirements of subsequent operations
Triton’s Quantized Kernel Support
Triton, as TorchInductor’s backend, supports quantized kernels primarily through:
- Bit manipulation primitives: packed low-bit weights are loaded as wider integers via tl.load, and the Triton compiler lowers the INT4 unpacking (shift + mask) into efficient PTX instructions
- Mixed-type dot: tl.dot(a, b) supports mixed-precision inputs (e.g., FP16 x FP16 with FP32 accumulate); the compiler automatically selects the corresponding mma instruction
- Custom dequant fusion: through Triton’s kernel fusion pass, tl.load + bit operations + tl.dot can be fused into a single load-compute pipeline
The torch.compile + TorchInductor workflow for quantized models:
- Dynamo captures the FX graph containing ops like quantize_per_tensor and dequantize_per_tensor
- TorchInductor’s lowering pass annotates these ops as "quantization_annotation"
- The fusion pass identifies dequant-matmul patterns and generates fused Triton kernels
- The Triton compiler completes final PTX code generation
As of 2025, torch.compile’s W4A16 support is quite mature (via torch.ops.aten._weight_int4pack_mm), but compiler paths for W8A8 and FP8 are still under active development. TensorRT-LLM provides more complete compilation support for these advanced quantization formats.
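As a rough illustration of what the compiler sees, here is a plain-PyTorch W4A16 linear that torch.compile can capture and fuse. The packing layout and group handling are simplified assumptions; this is the dequant math in reference form, not the _weight_int4pack_mm API itself:

import torch

def w4a16_linear(x, w_packed, scales, zeros, group_size=128):
    # Unpack 2 INT4 values per byte (low nibble = even K row; an assumed layout)
    lo = (w_packed & 0x0F).to(torch.float16)
    hi = ((w_packed >> 4) & 0x0F).to(torch.float16)
    w_int = torch.stack([lo, hi], dim=1).reshape(-1, w_packed.shape[-1])   # [K, N]
    # Group-wise dequant: one scale/zero_point per group of K rows
    s = scales.repeat_interleave(group_size, dim=0)                        # [K, N]
    z = zeros.repeat_interleave(group_size, dim=0)
    w = (w_int - z) * s
    return x @ w                                                           # FP16 MatMul

compiled = torch.compile(w4a16_linear)   # TorchInductor fuses the dequant arithmetic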
Practical Case: LLaMA 7B Quantization Compilation Pipeline
Let us use a concrete example to tie together all the concepts discussed above. Suppose we want to deploy LLaMA-2 7B on a single A100 80GB GPU, with the goal of maximizing single-card throughput.
Step 1: Quantization Scheme Selection
| Scheme | Model Size | Perplexity Degradation | Inference Engine Support |
|---|---|---|---|
| FP16 (baseline) | 14GB | 0 | All |
| GPTQ W4A16 g128 | 3.5GB | +0.1-0.3 | vLLM, TensorRT-LLM |
| AWQ W4A16 g128 | 3.5GB | +0.05-0.2 | vLLM, TensorRT-LLM |
| SmoothQuant W8A8 | 7GB | +0.05-0.1 | TensorRT-LLM |
We choose AWQ W4A16 g128 — the best balance of memory, quality, and compatibility.
Step 2: Compiler Optimization Pipeline
Using TensorRT-LLM as an example, the compiler performs the following optimizations on the AWQ-quantized LLaMA 7B:
a) Graph Import & Annotation
TensorRT-LLM loads weights from the HuggingFace quantized model and constructs the TensorRT network definition. Each Linear layer is represented as:
INT4_weight → IConstantLayer
scale/zeros → IConstantLayer
Input(FP16) → IDequantizeLayer → IMatMulLayer → ...
b) Quantization-Aware Fusion
TensorRT’s optimizer performs these fusions:
- 32 Transformer blocks x 4 Linear layers = 128 DequantMatMul fusions
- Each block’s QKV projection further fuses into QKV fused MatMul (3 MatMuls merged into 1)
- Attention scale + softmax fusion
- FFN Gate*Up element-wise multiplication fusion
c) Precision Routing
The compiler determines full-graph precision:
- All MatMuls: INT4→FP16 dequant inside TC, FP32 accumulate
- Softmax: FP32 internal
- RMSNorm: FP32 internal
- Residual add: FP16
- All intermediate activations: FP16
d) Kernel Selection
For each fused operator, the compiler selects the optimal implementation from the kernel library:
- INT4 MatMul → Marlin-style W4A16 kernel
- FlashAttention → FA2 kernel with FP16 Q/K/V
- RMSNorm → fused kernel with FP32 variance computation
Step 3: Performance Data
Typical performance data on A100 80GB (batch size = 1, sequence length = 2048):
| Configuration | Prefill (tok/s) | Decode (tok/s) | Memory Usage |
|---|---|---|---|
| FP16 (baseline) | ~4000 | ~40 | 14GB |
| AWQ W4A16 (naive) | ~5000 | ~80 | 3.5GB |
| AWQ W4A16 (optimized) | ~6000 | ~140 | 3.5GB |
- Naive: No fusion, each dequant as an independent kernel
- Optimized: Full fusion + optimized kernels
Key observations:
- The decode phase (memory-bound) benefits most from fusion: 40 → 140 tok/s (3.5x)
- The prefill phase (compute-bound) benefits less: 4000 → 6000 tok/s (1.5x)
- Significant memory reduction enables larger batch sizes or longer contexts
Step 4: Compiler Optimization Contribution Breakdown
The decode speedup from AWQ W4A16 (optimized) vs FP16 baseline (40 → 140 tok/s = 3.5x) can be decomposed:
- Weight size reduction (4x): INT4 weights are 1/4 of FP16 → theoretical 4x bandwidth savings → but requires dequant overhead
- Dequant fusion (~0.9x factor): even after fusion, on-the-fly dequant adds approximately 10% overhead
- Actual speedup: 4x x 0.9 x 0.97 (other overhead) ≈ 3.5x
Without fusion (naive), the independent dequant kernel overhead is approximately 40%, yielding actual speedup of only 4x x 0.6 ≈ 2.4x. Compiler fusion contributes approximately 45% additional speedup (from 2.4x to 3.5x).
Summary and Outlook
Quantization compilation is the critical bridge connecting algorithms (quantization methods) and hardware (Tensor Core, DRAM bandwidth). This article covered three core topics:
- Quantization-Aware Fusion: Absorbing dequant/quant nodes into compute kernels, eliminating intermediate tensor memory round-trips. Three fusion patterns — Weight Dequant Fusion, Epilogue Quant Fusion, and Full Sandwich Fusion — each suit different quantization schemes.
- Mixed-Precision Compilation: Using precision propagation algorithms to determine full-graph precision allocation, maintaining high precision in safety-critical operations (Softmax, LayerNorm) while minimizing type conversion overhead.
- Quantized Kernel Generation: For three mainstream schemes — W4A16 (INT4 dequant + FP16 TC), W8A8 (INT8 TC + requantize), and FP8 (H100 native) — the compiler must select or generate fundamentally different kernel implementations.
These optimizations make a significant practical contribution to end-to-end inference performance. Using LLaMA 7B AWQ W4A16 as an example, compiler optimization boosts decode throughput from a naive ~80 tok/s to ~140 tok/s.
Outlook:
- FP4 (NF4) compilation support: As techniques like QLoRA gain popularity, compilation optimization for FP4 formats will become the next frontier
- Dynamic quantization compiler integration: Most current compilers assume static quantization parameters; supporting per-token dynamic quantization requires deep compiler-runtime integration
- Speculative decoding + quantization: Combining quantization compilation with speculative decoding to further boost inference throughput
The next article will explore distributed compilation — when models are deployed across multiple GPUs/machines, how the compiler handles communication-computation overlap optimization for tensor parallelism and pipeline parallelism.
Further Reading
- Gholami et al., “A Survey of Quantization Methods for Efficient Neural Network Inference” (2021) — comprehensive survey of quantization methods
- Frantar et al., “GPTQ: Accurate Post-Training Quantization for Generative Pre-Trained Transformers” (2022) — detailed GPTQ algorithm
- Lin et al., “AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration” (2023) — AWQ algorithm and kernel design
- Micikevicius et al., “FP8 Formats for Deep Learning” (2022) — FP8 format design and hardware support
- NVIDIA TensorRT Developer Guide — industrial practice of quantization compilation
- PyTorch Quantization Documentation — quantization compilation support in the PyTorch ecosystem