GEMM (General Matrix Multiply) 是 LLM 训练和推理中占比最高的计算操作。本文从最简单的 naive 实现出发,逐步添加优化 — tiling、thread tiling、向量化访存、双缓冲、Tensor Core — 直到接近 cuBLAS 的理论峰值性能。
1. 为什么 GEMM 是 LLM 的核心
Transformer 的每一层包含多个矩阵乘法: QKV 投影、attention score、output projection、FFN 的两个线性层。对于一个典型的 LLM (如 hidden_dim=4096),超过 90% 的计算量来自这些 GEMM 操作。
GEMM 的计算量为 2MNK FLOPs (M、N、K 分别是矩阵 A、B、C 的维度)。对于大矩阵,arithmetic intensity (计算密度) 远高于 roofline 的 ridge point,属于 compute-bound — 这意味着理论上可以充分利用 GPU 的计算单元。
2. Naive 实现 — 基线
最简单的 CUDA GEMM: 每个线程计算输出矩阵 C 的一个元素。内循环沿 K 维做乘加:
C[row][col] = sum(A[row][k] * B[k][col]) for k = 0..K-1
每个输出元素需要从全局内存 (HBM) 读取 A 的一行 (K 个元素) 和 B 的一列 (K 个元素),总共 2MNK 次内存访问。
Naive GEMM: 一个线程算一个输出元素
问题很明显: 同一行 A 数据被 C 的同一行的所有线程重复读取。M×N 个线程各自独立读取 — 大量冗余访存。
3. 优化 1 — Tiling + Shared Memory
核心思想: 把大矩阵切成 BLOCK_SIZE×BLOCK_SIZE 的 tile,每次从 HBM 加载一对 tile 到 shared memory,block 内所有线程共享复用。
全局内存访问从 O(MNK) 降到 O(MNK/BLOCK_SIZE) — 减少 BLOCK_SIZE 倍。
Step 1: 矩阵切分为 Tile 网格
双重 __syncthreads() 确保: (1) 所有线程加载完 tile 后再开始计算; (2) 计算完当前 tile 后再加载下一个。
4. 优化 2 — Thread Tiling (每线程多元素)
Tiling 解决了 HBM 带宽问题,但 shared memory 带宽也会成为瓶颈。每线程只算 1 个元素时,内循环每步读 2 次 shared memory、做 1 次 FMA — compute:load 比只有 0.5。
Thread tiling: 每个线程负责 TM×TN 个输出元素 (如 4x4=16 个)。A 的 TM 个值和 B 的 TN 个值加载到寄存器后,产生 TM×TN 次 FMA — compute:load 比提升到 TM+TNTM×TN。
1x1 Thread Tile: 一个线程算一个元素
5. 优化 3 — 向量化访存
GPU 内存总线支持 32/64/128-bit 宽度的 load 指令。使用 float4 (128-bit) 一次加载 4 个 float,比 4 次标量 load 减少 3/4 的指令调度开销。
要求数据地址 128-bit 对齐。在 tiling 中,tile 的起始地址通常自然对齐。
6. 优化 4 — 双缓冲 Prefetch
计算当前 tile 时,同时预加载下一个 tile — 重叠访存和计算延迟。
无双缓冲: Load 和 Compute 串行
7. 优化 5 — Tensor Core GEMM
从 CUDA Core 切换到 Tensor Core: 使用 WMMA (Warp Matrix Multiply-Accumulate) API,一条 warp 级指令完成 16×16×16 的矩阵块乘加。
仍然需要 block 级 tiling (shared memory) — Tensor Core 只是替换了最内层的计算单元。
WMMA 的三步流程: load_matrix_sync (shared memory → register fragment) → mma_sync (Tensor Core 执行) → store_matrix_sync (写回)。
Step 1: load_matrix_sync — 加载 Fragment
FP16 输入 + FP32 累加 = 精度损失可控 + 吞吐量提升 4-8 倍。
8. 性能阶梯总结
每步优化带来的性能提升 (以 H100 上 4096x4096 为参考):
从 naive 的不到 1% 利用率,到 Tensor Core 接近 90% — 核心思路始终是: 减少内存访问 → 提高数据复用 → 利用专用硬件。
9. Intel iGPU 上的 GEMM
Intel Xe2 (Lunar Lake / Panther Lake) 的 GEMM 优化思路与 CUDA 完全相同 — 只是术语和 API 不同:
核心映射: shared memory → SLM、warp → sub-group、Tensor Core → XMX、wmma → joint_matrix。优化的本质不变: 数据从远存搬到近存,在最快的存储层级上最大化复用。