本站内容由 AI 生成,可能存在错误。如发现问题,欢迎到 GitHub Issues 反馈。

GEMM 优化 — 从 Naive 到极致

GEMM 优化 — 从 Naive 到极致

更新于 2026-04-04

GEMM (General Matrix Multiply) 是 LLM 训练和推理中占比最高的计算操作。本文从最简单的 naive 实现出发,逐步添加优化 — tiling、thread tiling、向量化访存、双缓冲、Tensor Core — 直到接近 cuBLAS 的理论峰值性能。

1. 为什么 GEMM 是 LLM 的核心

Transformer 的每一层包含多个矩阵乘法: QKV 投影、attention score、output projection、FFN 的两个线性层。对于一个典型的 LLM (如 hidden_dim=4096),超过 90% 的计算量来自这些 GEMM 操作。

Transformer Block 中的 GEMM 操作蓝色标注 = 矩阵乘法 (占计算量 90%+)。S = seq_len, H = hidden_dimInput Embedding (S×H)Multi-Head AttentionQKV ProjectionS×H · H×3HAttention ScoreS×H · H×SAttention OutputS×S · S×HOutput ProjectionS×H · H×HAdd & LayerNormFeed-Forward NetworkFFN UpS×H · H×4HFFN DownS×4H · 4H×HAdd & LayerNorm每个 Transformer 层包含 6 个 GEMM — 它们决定了推理和训练的计算时间

GEMM 的计算量为 2MNK2MNK FLOPs (M、N、K 分别是矩阵 A、B、C 的维度)。对于大矩阵,arithmetic intensity (计算密度) 远高于 roofline 的 ridge point,属于 compute-bound — 这意味着理论上可以充分利用 GPU 的计算单元。

C(4096x4096) = A(4096x4096) * B(4096x4096)FLOPs = 2MNK = 137.4GMemory = 4(MK + KN + MN) = 201.3M bytesArithmetic Intensity = FLOPs / Bytes = 682.7 FLOPs/byteRoofline 位置 (H100 FP32 CUDA Core)ridge: 20AI682.7Memory-boundCompute-boundCompute-bound: 计算量充足,可以充分利用 Tensor Core典型 LLM: M=batch*seq, K=N=hidden_dim (4096+) → AI 通常 > 100 → Compute-boundGEMM 优化的目标: 让实际 FLOPS 接近 peak (FP32: 67T, FP16 TC: 990T)

2. Naive 实现 — 基线

最简单的 CUDA GEMM: 每个线程计算输出矩阵 C 的一个元素。内循环沿 K 维做乘加:

C[row][col] = sum(A[row][k] * B[k][col]) for k = 0..K-1

每个输出元素需要从全局内存 (HBM) 读取 A 的一行 (K 个元素) 和 B 的一列 (K 个元素),总共 2MNK2MNK 次内存访问。

Naive GEMM: 一个线程算一个输出元素
C = A x B (4x4) — 每个线程计算 C 的一个元素高亮线程负责计算 C[1][2]A (4x4)2130102103122103xB (4x4)1021310202101131=C (4x4)C00C01C02C03C10C11C12C13C20C21C22C23C30C31C32C33Naive kernel: 每个线程的工作C[row][col] = 0;for (k = 0; k < K; k++) C[row][col] += A[row][k] * B[k][col]; // 2 global reads per iteration每个输出元素需要 2K 次全局内存读取 → 总共 2MNK 次 → 严重 memory-bound

问题很明显: 同一行 A 数据被 C 的同一行的所有线程重复读取。M×NM \times N 个线程各自独立读取 — 大量冗余访存。

3. 优化 1 — Tiling + Shared Memory

核心思想: 把大矩阵切成 BLOCK_SIZE×BLOCK_SIZE\text{BLOCK\_SIZE} \times \text{BLOCK\_SIZE} 的 tile,每次从 HBM 加载一对 tile 到 shared memory,block 内所有线程共享复用。

全局内存访问从 O(MNK)O(MNK) 降到 O(MNK/BLOCK_SIZE)O(MNK / \text{BLOCK\_SIZE}) — 减少 BLOCK_SIZE 倍。

Step 1: 矩阵切分为 Tile 网格
Tiling: 把大矩阵切成 BLOCK_SIZE x BLOCK_SIZE 的小块每个 Block 负责计算 C 的一个 tile,沿 K 维遍历 A/B 的 tile 对A (M x K)xB (K x N)=C (M x N)Tiling 策略1. 每个 CUDA Block 对应 C 的一个 tile (BLOCK_SIZE x BLOCK_SIZE 线程)2. 计算 C[tile_r][tile_c] 需要 A 的第 tile_r 行所有 tile x B 的第 tile_c 列所有 tile3. 外循环: for t = 0 to K/BLOCK_SIZE — 每次加载一对 tile 到 Shared Memory4. 关键: 每个 tile 从 HBM 只加载一次,被 BLOCK_SIZE 个线程共享复用

双重 __syncthreads() 确保: (1) 所有线程加载完 tile 后再开始计算; (2) 计算完当前 tile 后再加载下一个。

4. 优化 2 — Thread Tiling (每线程多元素)

Tiling 解决了 HBM 带宽问题,但 shared memory 带宽也会成为瓶颈。每线程只算 1 个元素时,内循环每步读 2 次 shared memory、做 1 次 FMA — compute:load 比只有 0.5。

Thread tiling: 每个线程负责 TM×TNTM \times TN 个输出元素 (如 4x4=16 个)。A 的 TM 个值和 B 的 TN 个值加载到寄存器后,产生 TM×TNTM \times TN 次 FMA — compute:load 比提升到 TM×TNTM+TN\frac{TM \times TN}{TM + TN}

1x1 Thread Tile: 一个线程算一个元素
Naive Tiling: 每个线程计算 C 的 1 个元素每次内循环迭代: 从 shared memory 读 2 个值 (A 和 B),做 1 次 FMAC tile (BLOCK_SIZE x BLOCK_SIZE)Thread(2,3)低效: 计算/访存比 = 1:2内循环每次: 读 As[ty][k] + 读 Bs[k][tx] = 2 次 shared memory 读计算: 1 次 FMA (fused multiply-add)Compute:Load = 1 FMA / 2 reads = 0.5 — shared memory 带宽成为瓶颈
Thread Tile = 4 x 4 = 16 个输出元素A[4]B[4]C[4x4]每次内循环 k 步:从 shared mem 读: TM + TN = 8 FMA 计算: TM x TN = 16 Compute:Load = 4x4 / (4+4) = 2.00良好寄存器使用: C 累加器 (4x4=16) + A 片段 (4) + B 片段 (4) = 24 个寄存器寄存器越多 → thread tile 越大 → 比率越高,但 occupancy 可能下降 (trade-off)

5. 优化 3 — 向量化访存

GPU 内存总线支持 32/64/128-bit 宽度的 load 指令。使用 float4 (128-bit) 一次加载 4 个 float,比 4 次标量 load 减少 3/4 的指令调度开销。

向量化访存: float vs float4标量加载 (4 条指令)LDG.32 R0, [addr + 0]32bLDG.32 R1, [addr + 4]32bLDG.32 R2, [addr + 8]32bLDG.32 R3, [addr + 12]32b向量加载 (1 条指令)LDG.128 R0:R3, [addr]一次读取 128 bits = 4 个 float128b对比指令数4 条 LDG.321 条 LDG.128总传输量4 x 32b = 128b1 x 128b = 128b指令发射开销4 个调度槽1 个调度槽float4 tmp = *reinterpret_cast<float4*>(&A[row * K + k]); // 128-bit aligned load

要求数据地址 128-bit 对齐。在 tiling 中,tile 的起始地址通常自然对齐。

6. 优化 4 — 双缓冲 Prefetch

计算当前 tile 时,同时预加载下一个 tile — 重叠访存和计算延迟。

无双缓冲: Load 和 Compute 串行
串行流水线: 计算必须等加载完成每个 tile: 加载到 shared memory → __syncthreads() → 计算 → __syncthreads() → 下一个时间 →LoadLoad T0Load T1Load T2Load T3ComputeCompute T0Compute T1Compute T2Compute T3idleidleidleidleidleidleidleidle总时间 = 4 x (Load + Compute) — 一半时间在空闲!问题: 加载和计算不能重叠加载 tile 到 shared memory 后才能计算 → 计算完才能加载下一个 tile原因: 只有一块 shared memory buffer,加载和计算操作的是同一块内存解决方案: 用两块 buffer — 一块加载新数据,同时另一块用于计算(或者用寄存器预加载: 先读到寄存器,计算完当前 tile 后再写入 shared memory)

7. 优化 5 — Tensor Core GEMM

从 CUDA Core 切换到 Tensor Core: 使用 WMMA (Warp Matrix Multiply-Accumulate) API,一条 warp 级指令完成 16×16×1616 \times 16 \times 16 的矩阵块乘加。

仍然需要 block 级 tiling (shared memory) — Tensor Core 只是替换了最内层的计算单元。

Tensor Core GEMM 的多级 Tiling从 Grid 到 Tensor Core,每级 tile 缩小到适合当前硬件层级Level 1: Grid Tile每个 CUDA Block 负责 C 的一个 BM x BN 区域 (如 128 x 128)决定: 每个 Block 的工作量、shared memory 需求Level 2: Warp TileBlock 内每个 Warp 负责 C 的 WM x WN 区域 (如 32 x 64)决定: Block 内 Warp 的分工、寄存器分配Level 3: MMA Instruction Tile每个 Warp 内层循环用 wmma::mma_sync 做 16 x 16 x 16 (或 m16n8k16) 的矩阵块乘这一步由 Tensor Core 硬件执行 — 一条指令完成整块乘加Tensor Core: D(16x16) = A(16x16) * B(16x16) + C(16x16)wmma::mma_sync<16,16,16,half> (底层映射到多条 PTX mma.sync 指令)典型尺寸 (H100 HGEMM)Grid tile (Block)128 x 128shared memory 中Warp tile32 x 64register fragment 中MMA tile (WMMA)16 x 16 x 16Tensor Core 一条 warp 级指令

WMMA 的三步流程: load_matrix_sync (shared memory → register fragment) → mma_sync (Tensor Core 执行) → store_matrix_sync (写回)。

Step 1: load_matrix_sync — 加载 Fragment
WMMA Step 1: 从 Shared Memory 加载 Fragment32 个线程协作将矩阵块加载到各自的寄存器中 (fragment)Shared MemoryAs[16x16] + Bs[16x16]Warp (32 threads) — 各线程的寄存器A fragmentwmma::matrix_a<16,16,16,half>B fragmentwmma::matrix_b<16,16,16,half>C accumulatorwmma::accumulator<float>// 声明 fragment (每个线程持有矩阵的一部分)wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;wmma::load_matrix_sync(a_frag, &As[warp_row * 16], 16); // shared → registerFragment 分布16x16 = 256 个元素由 32 个线程分持: 每个线程寄存器中持有 8 个元素具体哪个线程持有哪些元素由硬件决定 — 对程序员不透明 (只能通过 wmma API 操作)

FP16 输入 + FP32 累加 = 精度损失可控 + 吞吐量提升 4-8 倍。

8. 性能阶梯总结

每步优化带来的性能提升 (以 H100 上 4096x4096 为参考):

GEMM 优化性能阶梯 (H100, 4096x4096)每步优化的 GFLOPS (前 5 步: FP32 SGEMM, 最后一步: FP16 HGEMM) — 悬停查看详情cuBLAS ~65K (97%)Naive400 (0.6%)+ Block Tiling8K (12%)+ Thread Tile25K (37%)+ Vec Load35K (52%)+ Double Buffer45K (67%)Tensor Core (FP16)60K (~90%)悬停各优化阶段查看详情优化核心: 减少内存访问 → 提高数据复用 → 利用专用硬件 (Tensor Core) → 接近理论峰值

从 naive 的不到 1% 利用率,到 Tensor Core 接近 90% — 核心思路始终是: 减少内存访问 → 提高数据复用 → 利用专用硬件

9. Intel iGPU 上的 GEMM

Intel Xe2 (Lunar Lake / Panther Lake) 的 GEMM 优化思路与 CUDA 完全相同 — 只是术语和 API 不同:

CUDA GEMM vs Intel GEMM: Tiling 层级对照思路完全相同 — 只是 API 和硬件单元名称不同CUDA (NVIDIA)SYCL / DPC++ (Intel)Grid 级Block tile: BM x BN(Shared Memory)Work-group tile: BM x BN(SLM / Shared Local Memory)Warp/Sub-group 级Warp tile: WM x WN(Register fragment)Sub-group tile: WM x WN(GRF / Register)指令级wmma::mma_sync(Tensor Core 16x16x16)joint_matrix_mad(XMX systolic array)数据类型FP16 in → FP32 acc(mma.sync.f32.f16)FP16/BF16 in → FP32 acc(dpas / XMX)编程 APICUDA wmma / mma.sync(PTX / CUTLASS)SYCL joint_matrix(或 ESIMD intrinsics)核心相同: HBM → SLM/smem (block tile) → Register (warp/sub-group tile) → 矩阵硬件 (TC/XMX)

核心映射: shared memory → SLM、warp → sub-group、Tensor Core → XMX、wmma → joint_matrix。优化的本质不变: 数据从远存搬到近存,在最快的存储层级上最大化复用。