本站内容由 AI 生成,可能存在错误。如发现问题,欢迎到 GitHub Issues 反馈。

MQA 与 GQA

MQA 与 GQA

更新于 2026-04-23

简介:为什么 MHA 的 KV Cache 是瓶颈

在上一篇文章中,我们了解了 Multi-Head Attention (MHA) 的工作原理:hh 个 head 各自拥有独立的 WQW^QWKW^KWVW^V 投影矩阵,在不同的子空间中并行计算注意力。

MHA 在训练时表现出色 — 所有 token 可以并行处理。但在推理(自回归生成)时,一个严重的效率瓶颈浮现:KV Cache

什么是 KV Cache

在自回归生成中,每个新 token 需要与所有之前的 token 做注意力计算。为了避免重复计算,我们会把之前 token 的 Key 和 Value 缓存起来 — 这就是 KV Cache

对于标准 MHA,每一层的 KV Cache 大小为:

KV Cache per layer=2×h×S×dk×sizeof(dtype)\text{KV Cache per layer} = 2 \times h \times S \times d_k \times \text{sizeof(dtype)}

其中 hh 是 head 数,SS 是序列长度,dkd_k 是每头维度,factor 2 代表 K 和 V 两份缓存。

以 LLaMA-2 70B 为例(h=64h = 64dk=128d_k = 128L=80L = 80 层),生成长度 S=4096S = 4096 的序列时,每个请求的 KV Cache:

2×64×4096×128×80×2 bytes10.7 GB (FP16)2 \times 64 \times 4096 \times 128 \times 80 \times 2\text{ bytes} \approx 10.7 \text{ GB (FP16)}

这比模型参数本身的显存占用还要大!当需要同时服务多个用户(batch serving)时,KV Cache 会迅速耗尽 GPU 显存,成为吞吐量的核心瓶颈。

核心观察:KV Cache 的大小与 head 数 hh 成正比。如果我们能减少需要缓存的 KV head 数量,就能直接缩小 KV Cache。

KV Cache 内存瓶颈对比Llama-2 70B — 显存占用对比FP16, seq_len=4096, batch_size=32MHA 标准配置模型权重~140GBKV Cache (batch=32)~80GBGQA 优化配置模型权重~140GBKV Cache (batch=32)~10GBKV Cache 在大 batch 下可接近模型权重的显存占用!

MQA:共享 KV 的极致方案

Multi-Query Attention (MQA) 由 Noam Shazeer 在 2019 年提出(“Fast Transformer Decoding: One Write-Head is All You Need”),是最激进的 KV 缩减方案。

核心思想

MQA 的改动很简单:所有 query head 共享同一组 Key 和 Value

  • 每个 head 仍然有自己独立的 WiQW_i^Q(Query 投影各不相同)
  • 但所有 head 共用一个 WKW^K 和一个 WVW^V

数学表达:

headi=Attention(XWiQ,XWK,XWV)\text{head}_i = \text{Attention}(X W_i^Q, \, X W^K, \, X W^V)

注意 WKW^KWVW^V 不再有下标 ii — 它们在所有 head 之间是共享的。

MQA 共享 KV 结构Multi-Query Attention (MQA)QueryQ0Q1Q2Q3Q4Q5Q6Q7KVKV8 Query Heads → 1 共享 KV Head内存 ÷8但质量可能下降

KV Cache 缩减

由于只有一组 KV,缓存大小从:

2×h×S×dkMQA2×1×S×dk2 \times h \times S \times d_k \quad \xrightarrow{\text{MQA}} \quad 2 \times 1 \times S \times d_k

缩减了 hh 倍!对于 h=64h = 64 的模型,这意味着 KV Cache 缩小为原来的 164\frac{1}{64}

代价

MQA 的缩减是极端的,也带来了明显的代价:

  • 质量下降:所有 head 被迫在同一个 KV 子空间中计算注意力,丧失了 MHA 中”不同 head 关注不同模式”的能力
  • 训练不稳定:从零训练 MQA 模型时,收敛可能更困难
  • 论文报告”only minor quality degradation”,但实际在下游任务上,尤其是需要细粒度推理的任务中,质量损失可能更明显

GQA:分组查询的折中方案

Grouped-Query Attention (GQA) 由 Ainslie 等人在 2023 年提出,是 MHA 和 MQA 之间的优雅折中。

核心思想

hh 个 query head 分为 gg 个组,每组共享一对 KV head

  • g=hg = h(每组一个 head)→ 退化为标准 MHA
  • g=1g = 1(所有 head 一组)→ 退化为 MQA
  • 1<g<h1 < g < h → GQA,折中方案

数学表达(设第 ii 个 query head 属于第 ig/h\lfloor i \cdot g / h \rfloor 组):

headi=Attention(XWiQ,XWgroup(i)K,XWgroup(i)V)\text{head}_i = \text{Attention}(X W_i^Q, \, X W_{\text{group}(i)}^K, \, X W_{\text{group}(i)}^V)
GQA 分组 Head 结构Grouped-Query Attention (GQA)32 Q Heads → 8 KV Groups (4:1 ratio)KVG0KVG1KVG2KVG3KVG4KVG5KVG6KVG7QKVMHA: Q:KV = 32:32GQA: Q:KV = 32:8MQA: Q:KV = 32:1

KV Cache 缩减

2×h×S×dkGQA2×g×S×dk2 \times h \times S \times d_k \quad \xrightarrow{\text{GQA}} \quad 2 \times g \times S \times d_k

缩减了 h/gh / g 倍。例如 h=64h = 64g=8g = 8 时,KV Cache 缩小为原来的 18\frac{1}{8}

关键创新:Uptraining

GQA 论文的另一个重要贡献是提出了从已有的 MHA checkpoint uptraining 到 GQA 的方法:

  1. 将原始 MHA 中每组内的多个 KV head 的权重取均值,初始化 GQA 的共享 KV head
  2. 只需原始预训练计算量的约 5% 即可完成转换
  3. 转换后的模型质量接近原始 MHA,推理速度接近 MQA

这意味着不需要从头训练 — 可以把现有的 MHA 模型高效地转换为 GQA 模型。

Step 1: 8 个 KV Head 权重矩阵

原始 MHA 模型有 8 个独立的 KV head,每个都有自己的权重矩阵。

KV1
KV2
KV3
KV4
KV5
KV6
KV7
KV8

结构对比图:MHA vs MQA vs GQA

MHA — 每个 Q 对应独立 KV

标准多头注意力:4 个 Q head 各自拥有独立的 KV head。

MHA一对一Q₁Q₂Q₃Q₄KV₁KV₂KV₃KV₄KV heads = h = 4

上图展示了三种注意力机制的 head-to-KV 映射关系(以 h=4h = 4 为例):

  • MHA:4 个 Query head 各自对应 1 个独立的 KV head(共 4 个 KV head)
  • GQAg=2g = 2):4 个 Query head 分为 2 组,每组共享 1 个 KV head(共 2 个 KV head)
  • MQA:4 个 Query head 全部共享 1 个 KV head(共 1 个 KV head)

KV Cache 内存分析:具体数值对比

让我们用真实模型的参数来计算 KV Cache 的内存占用。假设序列长度 S=4096S = 4096,FP16(2 bytes per element):

模型层数 LLhhdkd_kKV headsKV Cache / 请求
假设 MHA-70B806412864 (MHA)10.7 GB
LLaMA-2 70B (GQA)806412881.3 GB
假设 MQA-70B80641281 (MQA)0.17 GB

计算公式:

KV Cache=2×KV heads×S×dk×L×2 bytes\text{KV Cache} = 2 \times \text{KV heads} \times S \times d_k \times L \times 2\text{ bytes}

以 LLaMA-2 70B 的 GQA 配置为例:

2×8×4096×128×80×2=1.34 GB2 \times 8 \times 4096 \times 128 \times 80 \times 2 = 1.34 \text{ GB}

对比:从 MHA 的 10.7 GB 降到 GQA 的 1.3 GB,缩减了约 8 倍h/g=64/8=8h / g = 64 / 8 = 8),而如果用 MQA 则可以缩减 64 倍到仅 0.17 GB。

h=64, kv_heads=8, L=80, d_k=128
0.0GB80.0GB160.0GB240.0GB320.0GB2561K4K16K64K
MHA (320.0 GB @ 128K) GQA (40.0 GB) MQA (5.00 GB)

Batch Serving 的影响

KV Cache 缩减对批处理推理的影响更加显著。假设 GPU 有 40 GB 剩余显存用于 KV Cache:

方案KV Cache / 请求可并发请求数
MHA10.7 GB~3
GQA (8 组)1.3 GB~30
MQA0.17 GB~235

GQA 将并发能力提升了约 10 倍 — 这对 LLM 服务的成本和延迟有决定性影响。

40 GB
MHA
4 个并发(10.00 GB/req)
GQA
32 个并发(1.25 GB/req)
MQA
256 个并发(0.16 GB/req)

基于 LLaMA-2 70B 参数 (L=80, h=64, d_k=128, GQA kv_heads=8), FP16

质量与性能的 Trade-off

减少 KV head 数量本质上是一种信息压缩:强制多个 query head 在同一个 KV 子空间中寻找不同的注意力模式。

为什么 GQA 质量损失很小

  1. 冗余性:研究发现 MHA 中相邻 head 的 KV 投影往往高度相似 — 许多 head 学到了冗余的 KV 表示
  2. Query 多样性保留:GQA 保留了所有 query head 的独立性,只是共享了 KV 空间。Query 投影仍然可以在共享的 KV 空间中学习不同的注意力模式
  3. Uptraining 有效性:从 MHA checkpoint 通过均值池化初始化 + 少量继续训练,可以高效恢复质量

GQA 论文报告,使用原始预训练计算量约 5% 的 uptraining,GQA 模型在大多数基准测试中的表现接近原始 MHA 模型,同时推理速度接近 MQA。

速度提升的来源

KV Cache 缩减带来的速度提升主要来自两个方面:

  1. 内存带宽:自回归解码是 memory-bandwidth bound 的操作。KV Cache 更小意味着每步生成需要加载的数据更少,直接提升生成速度
  2. 内存容量:更小的 KV Cache 允许更大的 batch size,提升 GPU 利用率和整体吞吐量

实际应用

GQA 已成为当前主流大语言模型的标准配置:

模型Query HeadsKV Heads组比例 (h/g)注意力类型
LLaMA-2 7B32321:1MHA
LLaMA-2 13B40401:1MHA
LLaMA-2 70B6488:1GQA
LLaMA-3 8B3284:1GQA
LLaMA-3 70B6488:1GQA
Mistral 7B3284:1GQA
Gemini 1.0 Pro1MQA

值得注意的趋势

  • LLaMA-2 系列:只有最大的 70B 模型使用 GQA,较小的 7B 和 13B 仍使用标准 MHA。这说明在当时,KV Cache 瓶颈主要是大模型面临的问题
  • LLaMA-3 系列:所有尺寸(包括 8B)都采用 GQA,反映了 GQA 已被证明在各种规模下都有效
  • Mistral 7B:即使是 7B 规模也使用 GQA(4:1),结合滑动窗口注意力进一步优化推理效率
  • Gemini 1.0 Pro:使用更激进的 MQA 方案,所有 query head 共享单一 KV head
  • 行业共识:GQA 已成为新模型的默认选择,8 个 KV head 是一个常见的配置

PyTorch 实现要点

GQA 的实现只需要在标准 MHA 基础上做少量修改:

# GQA: g 个 KV head, h 个 query head, 每组 h//g 个 query 共享一个 KV
class GroupedQueryAttention(nn.Module):
    def __init__(self, H, h, g, d_k):
        super().__init__()
        self.h = h      # query head 数
        self.g = g      # KV head 数 (group 数)
        self.d_k = d_k
        
        self.W_q = nn.Linear(H, h * d_k)    # h 个 query head
        self.W_k = nn.Linear(H, g * d_k)    # g 个 KV head
        self.W_v = nn.Linear(H, g * d_k)    # g 个 KV head
        self.W_o = nn.Linear(h * d_k, H)
    
    def forward(self, x):
        B, S, _ = x.shape
        
        Q = self.W_q(x).view(B, S, self.h, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, S, self.g, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, S, self.g, self.d_k).transpose(1, 2)
        
        # 关键步骤:将 KV head 扩展以匹配 query head 数
        # 每个 KV head 被复制 h//g 次
        repeats = self.h // self.g
        K = K.repeat_interleave(repeats, dim=1)  # (B, h, S, d_k)
        V = V.repeat_interleave(repeats, dim=1)  # (B, h, S, d_k)
        
        # 标准 attention 计算
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        weights = F.softmax(scores, dim=-1)
        output = weights @ V
        
        output = output.transpose(1, 2).contiguous().view(B, S, -1)
        return self.W_o(output)

关键区别在于:W_kW_v 的输出维度是 g×dkg \times d_k(而非 h×dkh \times d_k),然后通过 repeat_interleave 将每个 KV head 复制 h/gh/g 次以匹配 query head 数量。注意这个复制操作不增加 KV Cache 的大小 — 缓存的仍然只有 gg 个 KV head。

总结

概念说明
MHA 瓶颈KV Cache 随 head 数线性增长,限制推理效率和并发能力
MQA所有 query head 共享一对 KV,KV Cache 缩减 hh 倍,但质量损失明显
GQAhh 个 query head 分 gg 组,每组共享 KV,折中方案
KV Cache 缩减MHA: 2hSdk2hSd_k → GQA: 2gSdk2gSd_k → MQA: 2Sdk2Sd_k
Uptraining从 MHA checkpoint 用约 5% 计算量转换为 GQA
行业趋势GQA 已成为 LLaMA-3、Mistral 等主流模型的标准配置

核心直觉:MHA 中大量 KV head 存在信息冗余。GQA 通过让一组 query head 共享同一对 KV head,在几乎不损失质量的前提下,将 KV Cache 缩减数倍,从而大幅提升推理效率和服务并发能力。这就像一个团队开会 — 不需要每个人都带一份完整的会议资料,几个人共享一份即可,节省的是桌面空间(显存),而不影响讨论质量(模型能力)。