MQA 与 GQA
更新于 2026-04-23
简介:为什么 MHA 的 KV Cache 是瓶颈
在上一篇文章中,我们了解了 Multi-Head Attention (MHA) 的工作原理: 个 head 各自拥有独立的 、、 投影矩阵,在不同的子空间中并行计算注意力。
MHA 在训练时表现出色 — 所有 token 可以并行处理。但在推理(自回归生成)时,一个严重的效率瓶颈浮现:KV Cache。
什么是 KV Cache
在自回归生成中,每个新 token 需要与所有之前的 token 做注意力计算。为了避免重复计算,我们会把之前 token 的 Key 和 Value 缓存起来 — 这就是 KV Cache。
对于标准 MHA,每一层的 KV Cache 大小为:
其中 是 head 数, 是序列长度, 是每头维度,factor 2 代表 K 和 V 两份缓存。
以 LLaMA-2 70B 为例(,, 层),生成长度 的序列时,每个请求的 KV Cache:
这比模型参数本身的显存占用还要大!当需要同时服务多个用户(batch serving)时,KV Cache 会迅速耗尽 GPU 显存,成为吞吐量的核心瓶颈。
核心观察:KV Cache 的大小与 head 数 成正比。如果我们能减少需要缓存的 KV head 数量,就能直接缩小 KV Cache。
MQA:共享 KV 的极致方案
Multi-Query Attention (MQA) 由 Noam Shazeer 在 2019 年提出(“Fast Transformer Decoding: One Write-Head is All You Need”),是最激进的 KV 缩减方案。
核心思想
MQA 的改动很简单:所有 query head 共享同一组 Key 和 Value。
- 每个 head 仍然有自己独立的 (Query 投影各不相同)
- 但所有 head 共用一个 和一个
数学表达:
注意 和 不再有下标 — 它们在所有 head 之间是共享的。
KV Cache 缩减
由于只有一组 KV,缓存大小从:
缩减了 倍!对于 的模型,这意味着 KV Cache 缩小为原来的 。
代价
MQA 的缩减是极端的,也带来了明显的代价:
- 质量下降:所有 head 被迫在同一个 KV 子空间中计算注意力,丧失了 MHA 中”不同 head 关注不同模式”的能力
- 训练不稳定:从零训练 MQA 模型时,收敛可能更困难
- 论文报告”only minor quality degradation”,但实际在下游任务上,尤其是需要细粒度推理的任务中,质量损失可能更明显
GQA:分组查询的折中方案
Grouped-Query Attention (GQA) 由 Ainslie 等人在 2023 年提出,是 MHA 和 MQA 之间的优雅折中。
核心思想
将 个 query head 分为 个组,每组共享一对 KV head。
- 当 (每组一个 head)→ 退化为标准 MHA
- 当 (所有 head 一组)→ 退化为 MQA
- 当 → GQA,折中方案
数学表达(设第 个 query head 属于第 组):
KV Cache 缩减
缩减了 倍。例如 、 时,KV Cache 缩小为原来的 。
关键创新:Uptraining
GQA 论文的另一个重要贡献是提出了从已有的 MHA checkpoint uptraining 到 GQA 的方法:
- 将原始 MHA 中每组内的多个 KV head 的权重取均值,初始化 GQA 的共享 KV head
- 只需原始预训练计算量的约 5% 即可完成转换
- 转换后的模型质量接近原始 MHA,推理速度接近 MQA
这意味着不需要从头训练 — 可以把现有的 MHA 模型高效地转换为 GQA 模型。
原始 MHA 模型有 8 个独立的 KV head,每个都有自己的权重矩阵。
结构对比图:MHA vs MQA vs GQA
标准多头注意力:4 个 Q head 各自拥有独立的 KV head。
上图展示了三种注意力机制的 head-to-KV 映射关系(以 为例):
- MHA:4 个 Query head 各自对应 1 个独立的 KV head(共 4 个 KV head)
- GQA():4 个 Query head 分为 2 组,每组共享 1 个 KV head(共 2 个 KV head)
- MQA:4 个 Query head 全部共享 1 个 KV head(共 1 个 KV head)
KV Cache 内存分析:具体数值对比
让我们用真实模型的参数来计算 KV Cache 的内存占用。假设序列长度 ,FP16(2 bytes per element):
| 模型 | 层数 | KV heads | KV Cache / 请求 | ||
|---|---|---|---|---|---|
| 假设 MHA-70B | 80 | 64 | 128 | 64 (MHA) | 10.7 GB |
| LLaMA-2 70B (GQA) | 80 | 64 | 128 | 8 | 1.3 GB |
| 假设 MQA-70B | 80 | 64 | 128 | 1 (MQA) | 0.17 GB |
计算公式:
以 LLaMA-2 70B 的 GQA 配置为例:
对比:从 MHA 的 10.7 GB 降到 GQA 的 1.3 GB,缩减了约 8 倍(),而如果用 MQA 则可以缩减 64 倍到仅 0.17 GB。
Batch Serving 的影响
KV Cache 缩减对批处理推理的影响更加显著。假设 GPU 有 40 GB 剩余显存用于 KV Cache:
| 方案 | KV Cache / 请求 | 可并发请求数 |
|---|---|---|
| MHA | 10.7 GB | ~3 |
| GQA (8 组) | 1.3 GB | ~30 |
| MQA | 0.17 GB | ~235 |
GQA 将并发能力提升了约 10 倍 — 这对 LLM 服务的成本和延迟有决定性影响。
基于 LLaMA-2 70B 参数 (L=80, h=64, d_k=128, GQA kv_heads=8), FP16
质量与性能的 Trade-off
减少 KV head 数量本质上是一种信息压缩:强制多个 query head 在同一个 KV 子空间中寻找不同的注意力模式。
为什么 GQA 质量损失很小
- 冗余性:研究发现 MHA 中相邻 head 的 KV 投影往往高度相似 — 许多 head 学到了冗余的 KV 表示
- Query 多样性保留:GQA 保留了所有 query head 的独立性,只是共享了 KV 空间。Query 投影仍然可以在共享的 KV 空间中学习不同的注意力模式
- Uptraining 有效性:从 MHA checkpoint 通过均值池化初始化 + 少量继续训练,可以高效恢复质量
GQA 论文报告,使用原始预训练计算量约 5% 的 uptraining,GQA 模型在大多数基准测试中的表现接近原始 MHA 模型,同时推理速度接近 MQA。
速度提升的来源
KV Cache 缩减带来的速度提升主要来自两个方面:
- 内存带宽:自回归解码是 memory-bandwidth bound 的操作。KV Cache 更小意味着每步生成需要加载的数据更少,直接提升生成速度
- 内存容量:更小的 KV Cache 允许更大的 batch size,提升 GPU 利用率和整体吞吐量
实际应用
GQA 已成为当前主流大语言模型的标准配置:
| 模型 | Query Heads | KV Heads | 组比例 (h/g) | 注意力类型 |
|---|---|---|---|---|
| LLaMA-2 7B | 32 | 32 | 1:1 | MHA |
| LLaMA-2 13B | 40 | 40 | 1:1 | MHA |
| LLaMA-2 70B | 64 | 8 | 8:1 | GQA |
| LLaMA-3 8B | 32 | 8 | 4:1 | GQA |
| LLaMA-3 70B | 64 | 8 | 8:1 | GQA |
| Mistral 7B | 32 | 8 | 4:1 | GQA |
| Gemini 1.0 Pro | — | 1 | — | MQA |
值得注意的趋势:
- LLaMA-2 系列:只有最大的 70B 模型使用 GQA,较小的 7B 和 13B 仍使用标准 MHA。这说明在当时,KV Cache 瓶颈主要是大模型面临的问题
- LLaMA-3 系列:所有尺寸(包括 8B)都采用 GQA,反映了 GQA 已被证明在各种规模下都有效
- Mistral 7B:即使是 7B 规模也使用 GQA(4:1),结合滑动窗口注意力进一步优化推理效率
- Gemini 1.0 Pro:使用更激进的 MQA 方案,所有 query head 共享单一 KV head
- 行业共识:GQA 已成为新模型的默认选择,8 个 KV head 是一个常见的配置
PyTorch 实现要点
GQA 的实现只需要在标准 MHA 基础上做少量修改:
# GQA: g 个 KV head, h 个 query head, 每组 h//g 个 query 共享一个 KV
class GroupedQueryAttention(nn.Module):
def __init__(self, H, h, g, d_k):
super().__init__()
self.h = h # query head 数
self.g = g # KV head 数 (group 数)
self.d_k = d_k
self.W_q = nn.Linear(H, h * d_k) # h 个 query head
self.W_k = nn.Linear(H, g * d_k) # g 个 KV head
self.W_v = nn.Linear(H, g * d_k) # g 个 KV head
self.W_o = nn.Linear(h * d_k, H)
def forward(self, x):
B, S, _ = x.shape
Q = self.W_q(x).view(B, S, self.h, self.d_k).transpose(1, 2)
K = self.W_k(x).view(B, S, self.g, self.d_k).transpose(1, 2)
V = self.W_v(x).view(B, S, self.g, self.d_k).transpose(1, 2)
# 关键步骤:将 KV head 扩展以匹配 query head 数
# 每个 KV head 被复制 h//g 次
repeats = self.h // self.g
K = K.repeat_interleave(repeats, dim=1) # (B, h, S, d_k)
V = V.repeat_interleave(repeats, dim=1) # (B, h, S, d_k)
# 标准 attention 计算
scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
weights = F.softmax(scores, dim=-1)
output = weights @ V
output = output.transpose(1, 2).contiguous().view(B, S, -1)
return self.W_o(output)
关键区别在于:W_k 和 W_v 的输出维度是 (而非 ),然后通过 repeat_interleave 将每个 KV head 复制 次以匹配 query head 数量。注意这个复制操作不增加 KV Cache 的大小 — 缓存的仍然只有 个 KV head。
总结
| 概念 | 说明 |
|---|---|
| MHA 瓶颈 | KV Cache 随 head 数线性增长,限制推理效率和并发能力 |
| MQA | 所有 query head 共享一对 KV,KV Cache 缩减 倍,但质量损失明显 |
| GQA | 个 query head 分 组,每组共享 KV,折中方案 |
| KV Cache 缩减 | MHA: → GQA: → MQA: |
| Uptraining | 从 MHA checkpoint 用约 5% 计算量转换为 GQA |
| 行业趋势 | GQA 已成为 LLaMA-3、Mistral 等主流模型的标准配置 |
核心直觉:MHA 中大量 KV head 存在信息冗余。GQA 通过让一组 query head 共享同一对 KV head,在几乎不损失质量的前提下,将 KV Cache 缩减数倍,从而大幅提升推理效率和服务并发能力。这就像一个团队开会 — 不需要每个人都带一份完整的会议资料,几个人共享一份即可,节省的是桌面空间(显存),而不影响讨论质量(模型能力)。