GEMM Optimization — From Naive to Peak Performance
Updated 2026-04-06
GEMM (General Matrix Multiply) is the dominant compute operation in LLM training and inference. This article starts from the simplest naive implementation and progressively adds optimizations — tiling, thread tiling, vectorized memory access, double buffering, Tensor Core — until we approach cuBLAS-level theoretical peak performance.
1. Why GEMM Is the Core of LLMs
Every layer of a Transformer contains multiple matrix multiplications: QKV projections, attention scores, output projection, and the two linear layers of the FFN. For a typical LLM (e.g., hidden_dim=4096), over 90% of compute comes from these GEMM operations.
A GEMM computing C = A×B performs 2MNK FLOPs, where A is M×K, B is K×N, and C is M×N. For large matrices, the arithmetic intensity is well above the roofline's ridge point, making the operation compute-bound: in theory, the GPU's compute units can be kept fully busy.
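As a rough worked example (figures are approximate): a 4096×4096×4096 FP16 GEMM performs 2×4096³ ≈ 137 GFLOP while moving at minimum about 3×4096²×2 bytes ≈ 100 MB between HBM and the chip, an arithmetic intensity of roughly 1365 FLOP/byte. Taking an H100 SXM's roughly 3.35 TB/s of HBM bandwidth and on the order of 1 PFLOPS of dense FP16 Tensor Core throughput, the ridge point sits near 300 FLOP/byte, so this problem lands well inside the compute-bound region.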
2. Naive Implementation — Baseline
The simplest CUDA GEMM: each thread computes one element of the output matrix C. The inner loop performs multiply-accumulate along the K dimension:
C[row][col] = sum(A[row][k] * B[k][col]) for k = 0..K-1
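A minimal sketch of such a kernel (the kernel name, launch configuration, and row-major FP32 layout are illustrative assumptions):

// Naive GEMM: C = A * B with A (M x K), B (K x N), C (M x N), all row-major FP32.
// One thread per output element; every operand is fetched from global memory (HBM).
__global__ void gemm_naive(const float* A, const float* B, float* C,
                           int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k)
            acc += A[row * K + k] * B[k * N + col];   // one multiply-accumulate per K step
        C[row * N + col] = acc;
    }
}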
Each output element requires reading one row of A (K elements) and one column of B (K elements) from global memory (HBM), so the kernel performs 2K reads per element and 2MNK reads in total.
Naive GEMM: One Thread per Output Element
The problem is obvious: the same row of A is read repeatedly by all threads computing the same row of C. M×N threads each read independently — massive redundant memory access.
3. Optimization 1 — Tiling + Shared Memory
Core idea: divide the large matrices into BLOCK_SIZE×BLOCK_SIZE tiles, load one pair of tiles from HBM into shared memory at a time, and let all threads within the block share and reuse the data.
Global memory access drops from O(MNK) to O(MNK/BLOCK_SIZE) — a BLOCK_SIZE-fold reduction.
Step 1: Matrix split into Tile grid
Double __syncthreads() ensures: (1) all threads finish loading the tile before computation begins; (2) computation of the current tile finishes before loading the next one.
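A sketch of the tiled kernel (BLOCK_SIZE = 32 and the kernel name are illustrative; for brevity it assumes M, N, K are multiples of BLOCK_SIZE, so bounds checks are omitted):

#define BLOCK_SIZE 32

__global__ void gemm_tiled(const float* A, const float* B, float* C,
                           int M, int N, int K) {
    __shared__ float As[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];

    int row = blockIdx.y * BLOCK_SIZE + threadIdx.y;
    int col = blockIdx.x * BLOCK_SIZE + threadIdx.x;
    float acc = 0.0f;

    // Walk a pair of tiles along the K dimension.
    for (int t = 0; t < K / BLOCK_SIZE; ++t) {
        As[threadIdx.y][threadIdx.x] = A[row * K + t * BLOCK_SIZE + threadIdx.x];
        Bs[threadIdx.y][threadIdx.x] = B[(t * BLOCK_SIZE + threadIdx.y) * N + col];
        __syncthreads();   // (1) whole tile loaded before anyone computes on it

        for (int k = 0; k < BLOCK_SIZE; ++k)
            acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        __syncthreads();   // (2) whole tile consumed before it is overwritten
    }
    C[row * N + col] = acc;
}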
4. Optimization 2 — Thread Tiling (Multiple Elements per Thread)
Tiling solves the HBM bandwidth problem, but shared memory bandwidth can also become a bottleneck. When each thread computes only 1 element, the inner loop reads shared memory twice and performs 1 FMA per step — the compute:load ratio is only 0.5.
Thread tiling: each thread handles a TM×TN tile of output elements (e.g., 4×4 = 16 elements). After loading TM values from A and TN values from B into registers, it performs TM×TN FMAs, so the compute:load ratio improves to TM×TN/(TM+TN), i.e., 16/8 = 2 for a 4×4 tile (see the sketch below).
1x1 Thread Tile: One thread computes one element
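A sketch of a thread-tiled kernel with TM = TN = 4 (the block-tile sizes BM/BN/BK, the cooperative-load scheme, and all names are illustrative choices; M, N, K are assumed to be multiples of the tile sizes):

#define BM 64   // block tile rows
#define BN 64   // block tile cols
#define BK 8    // K-slice staged per iteration
#define TM 4    // output rows per thread
#define TN 4    // output cols per thread

__global__ void gemm_thread_tiled(const float* A, const float* B, float* C,
                                  int M, int N, int K) {
    __shared__ float As[BM][BK];
    __shared__ float Bs[BK][BN];

    // Launch with (BM / TM) * (BN / TN) = 256 threads per block (1-D).
    int tx = threadIdx.x % (BN / TN);        // thread column within the block tile
    int ty = threadIdx.x / (BN / TN);        // thread row within the block tile
    int row0 = blockIdx.y * BM + ty * TM;    // top-left corner of this thread's C tile
    int col0 = blockIdx.x * BN + tx * TN;

    float a_reg[TM], b_reg[TN];
    float acc[TM][TN] = {{0.0f}};            // TM x TN accumulators live in registers

    for (int t = 0; t < K / BK; ++t) {
        // Cooperatively stage the BM x BK and BK x BN tiles in shared memory
        // (512 elements each, 2 per thread).
        for (int i = threadIdx.x; i < BM * BK; i += blockDim.x)
            As[i / BK][i % BK] = A[(blockIdx.y * BM + i / BK) * K + t * BK + i % BK];
        for (int i = threadIdx.x; i < BK * BN; i += blockDim.x)
            Bs[i / BN][i % BN] = B[(t * BK + i / BN) * N + blockIdx.x * BN + i % BN];
        __syncthreads();

        for (int k = 0; k < BK; ++k) {
            // TM + TN shared-memory loads ...
            for (int i = 0; i < TM; ++i) a_reg[i] = As[ty * TM + i][k];
            for (int j = 0; j < TN; ++j) b_reg[j] = Bs[k][tx * TN + j];
            // ... feed TM * TN fused multiply-adds.
            for (int i = 0; i < TM; ++i)
                for (int j = 0; j < TN; ++j)
                    acc[i][j] += a_reg[i] * b_reg[j];
        }
        __syncthreads();
    }

    for (int i = 0; i < TM; ++i)
        for (int j = 0; j < TN; ++j)
            C[(row0 + i) * N + col0 + j] = acc[i][j];
}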
5. Optimization 3 — Vectorized Memory Access
GPU memory buses support 32/64/128-bit wide load instructions. A single float4 (128-bit) load fetches 4 floats at once, cutting the number of load instructions, and the associated issue and scheduling overhead, to a quarter of what 4 scalar loads would need.
This requires 128-bit address alignment. In tiling, tile starting addresses are typically naturally aligned.
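A minimal sketch of such a load, assuming the element offset is a multiple of 4 so the address is 16-byte aligned (the helper name is illustrative):

// Read A[row][k0 .. k0+3] with a single 128-bit load instead of four 32-bit loads.
// Requires A to be 16-byte aligned and (row * K + k0) to be a multiple of 4.
__device__ __forceinline__ float4 load_float4(const float* A, int row, int K, int k0) {
    return *reinterpret_cast<const float4*>(A + row * K + k0);
}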
6. Optimization 4 — Double Buffering Prefetch
While computing the current tile, simultaneously prefetch the next tile — overlapping memory access with compute latency.
Without Double Buffering: Serial Load and Compute
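A sketch of a double-buffered version of the tiled kernel above (same illustrative assumptions: BS divides M, N, K; the prefetch goes through registers into the idle shared-memory buffer rather than using cp.async):

#define BS 32   // BLOCK_SIZE

__global__ void gemm_double_buffered(const float* A, const float* B, float* C,
                                     int M, int N, int K) {
    __shared__ float As[2][BS][BS];
    __shared__ float Bs[2][BS][BS];

    int row = blockIdx.y * BS + threadIdx.y;
    int col = blockIdx.x * BS + threadIdx.x;
    int numTiles = K / BS;
    float acc = 0.0f;

    // Preload tile 0 into buffer 0.
    As[0][threadIdx.y][threadIdx.x] = A[row * K + threadIdx.x];
    Bs[0][threadIdx.y][threadIdx.x] = B[threadIdx.y * N + col];
    __syncthreads();

    for (int t = 0; t < numTiles; ++t) {
        int cur = t & 1, nxt = cur ^ 1;
        // Issue the global loads for tile t+1 into the idle buffer;
        // they can overlap with the FMAs on the current buffer below.
        if (t + 1 < numTiles) {
            As[nxt][threadIdx.y][threadIdx.x] = A[row * K + (t + 1) * BS + threadIdx.x];
            Bs[nxt][threadIdx.y][threadIdx.x] = B[((t + 1) * BS + threadIdx.y) * N + col];
        }
        // Compute on the tile loaded in the previous iteration.
        for (int k = 0; k < BS; ++k)
            acc += As[cur][threadIdx.y][k] * Bs[cur][k][threadIdx.x];
        __syncthreads();   // one barrier per iteration instead of two
    }
    C[row * N + col] = acc;
}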
7. Optimization 5 — Tensor Core GEMM
Switching from CUDA Cores to Tensor Cores: with the WMMA (Warp Matrix Multiply-Accumulate) API, a single warp-level instruction performs a 16×16×16 matrix block multiply-accumulate.
Block-level tiling (shared memory) is still needed — Tensor Core only replaces the innermost compute unit.
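A minimal warp-per-tile sketch using the WMMA API (FP16 inputs with FP32 accumulation, no shared-memory staging; it assumes M, N, K are multiples of 16 and a launch with blockDim.x == 32 so that each threadIdx.y row of a block is one warp):

#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

__global__ void gemm_wmma(const half* A, const half* B, float* C,
                          int M, int N, int K) {
    // Each warp owns one 16x16 tile of C (launch: blockDim = (32, warps_per_block)).
    int warpN = blockIdx.x;                              // tile column index
    int warpM = blockIdx.y * blockDim.y + threadIdx.y;   // tile row index

    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
    wmma::fill_fragment(c_frag, 0.0f);

    // One warp-level mma_sync consumes a full 16x16x16 block per iteration.
    for (int k = 0; k < K; k += 16) {
        wmma::load_matrix_sync(a_frag, A + warpM * 16 * K + k, K);
        wmma::load_matrix_sync(b_frag, B + k * N + warpN * 16, N);
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    }
    wmma::store_matrix_sync(C + warpM * 16 * N + warpN * 16, c_frag, N, wmma::mem_row_major);
}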
8. Performance Summary
Performance improvement from each optimization step (using 4096×4096 matrices on H100 as reference):
From the naive implementation at less than 1% utilization to Tensor Core approaching 90% — the core philosophy remains: reduce memory access → increase data reuse → leverage dedicated hardware.
9. GEMM on Intel iGPU
The GEMM optimization approach on Intel Xe2 (Lunar Lake / Panther Lake) mirrors the one on CUDA; only the terminology and APIs differ:
Core mapping: shared memory → SLM, warp → sub-group, Tensor Core → XMX, wmma → joint_matrix. The essence of optimization doesn’t change: move data from far memory to near memory, and maximize reuse at the fastest storage level.