本站内容由 AI 生成,可能存在错误。如发现问题,欢迎到 GitHub Issues 反馈。

KV Cache 原理

KV Cache 原理

更新于 2026-04-23

简介:KV Cache 是推理加速的关键

在前一篇文章中,我们了解了 LLM 推理分为 Prefill 和 Decode 两个阶段。Decode 阶段是逐 token 自回归生成的 — 每一步只生成一个新 token。

这里有一个关键问题:在标准 Self-Attention 中,每个新 token 的 Query 需要与序列中所有之前 token 的 Key 做点积,然后用 所有之前 token 的 Value 做加权求和。如果每步都从头重新计算所有 token 的 K 和 V,计算量会随着生成长度二次增长 — 这是不可接受的。

KV Cache 的核心思想很简单:把已经计算过的 Key 和 Value 向量缓存起来,每步只计算新 token 的 K 和 V,然后追加到缓存中。 这将 Decode 阶段的计算复杂度从 O(N2)O(N^2) 降低到 O(N)O(N)(对线性层而言),是现代 LLM 推理系统的标配。

KV Cache 加速原理无 KV Cachet=11×1t=22×2t=33×3t=44×4总计: O(N²)有 KV Cachet=11t=22t=33t=44总计: O(N)O(N²) → O(N) 线性投影加速

问题:没有 Cache 时的 O(N2)O(N^2) 计算浪费

朴素自回归生成

考虑一个生成序列长度为 NN 的 Decode 过程。如果不使用 KV Cache,在第 tt 步生成第 tt 个 token 时,模型需要:

  1. 将前 tt 个 token 全部输入 Transformer
  2. 对每一层,计算所有 tt 个 tokenQQKKVV
  3. 执行完整的 Attention 计算
  4. 只取最后一个位置的输出,用于预测下一个 token

总计算量中,仅线性投影部分就需要:

t=1Nt×(2×3×d2)=3d2×N(N+1)O(N2d2)\sum_{t=1}^{N} t \times (2 \times 3 \times d^2) = 3d^2 \times N(N+1) \approx O(N^2 d^2)

其中 3d23d^2 对应 Q、K、V 三个投影矩阵的计算量(每个 token 经过 WQW^QWKW^KWVW^V)。

浪费在哪里?

关键观察:tt 步中前 t1t-1 个 token 的 K 和 V 与第 t1t-1 步完全相同。因为这些 token 的输入没有变化(Decoder 使用因果遮罩,前面 token 的表示不受后面 token 影响),它们的 K、V 投影结果也完全一样。

换句话说,每一步都在重复计算已经算过的 Key 和 Value — 这是纯粹的计算浪费。

生成 "world" (第 1 步)

第一步没有重复计算。

Helloworldgen world总计算: 2 · 浪费: 0 (0%)
新计算 重复计算

KV Cache 机制:缓存和追加

核心思路

KV Cache 的机制非常直接:

  1. Prefill 阶段:处理完整 prompt,计算所有 token 的 K、V,并将结果保存到缓存
  2. Decode 每一步
    • 只将新的 1 个 token 送入模型
    • 计算新 token 的 qqkkvv(向量,不是矩阵)
    • kkvv 追加到缓存末尾
    • qq 与完整的 K Cache 做 Attention
    • 用 Attention 权重对完整的 V Cache 加权求和

数学表示

设第 tt 步时:

  • 新 token 的 Query: qt=xtWQR1×dkq_t = x_t W^Q \in \mathbb{R}^{1 \times d_k}
  • 新 token 的 Key: kt=xtWKR1×dkk_t = x_t W^K \in \mathbb{R}^{1 \times d_k}
  • 新 token 的 Value: vt=xtWVR1×dvv_t = x_t W^V \in \mathbb{R}^{1 \times d_v}

缓存更新:

Kcacheconcat(Kcache, kt)Rt×dkK_{\text{cache}} \leftarrow \text{concat}(K_{\text{cache}},\ k_t) \in \mathbb{R}^{t \times d_k} Vcacheconcat(Vcache, vt)Rt×dvV_{\text{cache}} \leftarrow \text{concat}(V_{\text{cache}},\ v_t) \in \mathbb{R}^{t \times d_v}

Attention 计算:

scores=qtKcacheT/dkR1×t\text{scores} = q_t \cdot K_{\text{cache}}^T / \sqrt{d_k} \in \mathbb{R}^{1 \times t} outputt=softmax(scores)VcacheR1×dv\text{output}_t = \text{softmax}(\text{scores}) \cdot V_{\text{cache}} \in \mathbb{R}^{1 \times d_v}

计算量对比

无 KV Cache有 KV Cache
tt 步线性投影tt 个 token 全部重算仅 1 个新 token
线性投影 FLOPs/步2×3×t×d22 \times 3 \times t \times d^22×3×d22 \times 3 \times d^2
总线性投影 FLOPsO(N2d2)O(N^2 d^2)O(Nd2)O(Nd^2)
Attention FLOPs/步O(t2dk)O(t^2 d_k)O(tdk)O(t \cdot d_k)
代价无额外内存需要缓存 K, V

KV Cache 用内存换时间,将线性投影的总计算量从 O(N2)O(N^2) 降到 O(N)O(N)

深入理解:为什么只缓存 K 和 V?

为什么只缓存 K 和 VQ vs K vs V:缓存价值对比类型Decode 第 t 步的向量Q仅当前 token用完即弃不缓存K+new所有历史 token持续复用缓存V+new所有历史 token持续复用缓存Attention = Q_new × [K_cache]ᵀ → softmax → × [V_cache]Q 每步都是新的;K, V 是不断积累的历史

上面介绍了 KV Cache 的机制,但一个自然的问题是:Attention 计算涉及 Q、K、V 三个矩阵,为什么只缓存 K 和 V,不缓存 Q? 进一步地,为什么不直接缓存 QKTQK^T 的注意力分数?

为什么不缓存 Q?

关键在于 Q 是”用完即弃”的,而 K、V 是”持续被需要”的

在 decode 的第 tt 步,新 token 的 query qtq_t 只有一个用途:与所有已缓存的 K 做点积,计算 attention score。一旦当前步的 attention 计算完成,qtq_t 就不再被任何后续步骤使用。

反过来看 K 和 V 的使用模式:

  • ktk_t 被未来所有步骤需要:第 t+1,t+2,...,Nt+1, t+2, ..., N 步的新 query 都需要与 ktk_t 做点积
  • vtv_t 被未来所有步骤需要:每一步的 attention 加权求和都要包含 vtv_t
对象可缓存?原因
K未来每步的新 Q 都需要和所有历史 K 做点积
V未来每步的 attention 加权求和都需要所有历史 V
Q每步只用当前 token 的 Q,用完即弃,无跨步复用价值

用一个比喻来说:K 和 V 是黑板上的笔记 — 每个新同学(新 token)进来都要翻阅整块黑板;Q 是每个同学自己提的问题 — 问完就结束了,后面的同学不需要看你的问题。

为什么不缓存 QKTQK^T 注意力分数?

另一个直觉是:既然第 tt 步已经计算了 qtq_tk1,...,ktk_1, ..., k_t 的点积分数,能否把这些分数存下来供后面使用?

答案是不能,因为每一步的 query 都不同:

  • tt 步:scores=[qtk1, qtk2, ..., qtkt]\text{scores} = [q_t \cdot k_1,\ q_t \cdot k_2,\ ...,\ q_t \cdot k_t]
  • t+1t+1 步:scores=[qt+1k1, qt+1k2, ..., qt+1kt+1]\text{scores} = [q_{t+1} \cdot k_1,\ q_{t+1} \cdot k_2,\ ...,\ q_{t+1} \cdot k_{t+1}]

虽然 k1,...,ktk_1, ..., k_t 没有变化,但 qt+1qtq_{t+1} \neq q_t,所以每一项分数都必须用新的 query 重新计算。Attention score 是 Q 和 K 的二元函数 — K 端不变(所以缓存 K),但 Q 端每步都是全新的,因此函数值(score)没有任何一项可以复用。

Decode 的递推结构:动态规划视角

理解了”只缓存 K 和 V”之后,可以从更高的层面理解 KV Cache 的本质 — 它实际上是自回归推理的 memoization(记忆化)

在多层 Transformer 中,decode 第 tt 步的完整数据流如下:

新 token x_t(embedding)


┌─ Layer 1 ──────────────────────────────────────────────────┐
│  q_t, k_t, v_t = x_t · W_Q, W_K, W_V                     │
│  KV Cache[1] ← append(k_t, v_t)                           │
│  attn_out = softmax(q_t · K_cache^T / √d_k) · V_cache     │
│  h_t^(1) = FFN(out_proj(attn_out) + x_t)                  │
└────────────────────────────────────────────────────────────┘
  │ h_t^(1)

┌─ Layer 2 ──────────────────────────────────────────────────┐
│  q_t, k_t, v_t = h_t^(1) · W_Q, W_K, W_V                 │
│  KV Cache[2] ← append(k_t, v_t)                           │
│  attn_out = softmax(q_t · K_cache^T / √d_k) · V_cache     │
│  h_t^(2) = FFN(out_proj(attn_out) + h_t^(1))              │
└────────────────────────────────────────────────────────────┘
  │ ... 逐层向上 ...

┌─ Layer L → h_t^(L) ───────────────────────────────────────┐
└────────────────────────────────────────────────────────────┘


lm_head(h_t^(L)) → logits → 预测下一个 token

注意其中的递推依赖:第 \ell 层的 K/V 来自上一层的输出:

kt()=ht(1)WK(),vt()=ht(1)WV()k_t^{(\ell)} = h_t^{(\ell-1)} \cdot W_K^{(\ell)}, \quad v_t^{(\ell)} = h_t^{(\ell-1)} \cdot W_V^{(\ell)}

ht(1)h_t^{(\ell-1)} 本身需要第 1\ell-1 层的 attention 计算才能得到。这意味着:

  • 第 1 层的 K/V 可以直接从 token embedding 计算
  • 但第 2 层及以后的 K/V 都依赖前一层的 attention 输出 — 这是一个逐层递推的过程

Prefill 阶段不能跳过 attention 也是同样的原因:即使我们的目标是”填充 KV Cache”,深层的 K/V 也必须通过浅层的完整 attention 计算才能得到。Prefill 本质上是一次完整的前向传播,KV Cache 是这个过程的副产品

这个递推结构天然适合用**动态规划(Dynamic Programming)**的框架来理解:

DP 概念KV Cache 对应
子问题每个 (layer, position) 处的 K/V 向量
递推关系kt()=f(ht(1))k_t^{(\ell)} = f(h_t^{(\ell-1)}),深层依赖浅层 attention 输出
无后效性因果 mask 保证:位置 ii 的 K/V 只取决于 1,...,i1, ..., i,不受未来 token 影响
Memoization将算过的 K/V 存入 cache,后续步骤直接查表复用
复杂度优化线性投影总量从 O(N2d2)O(N^2 d^2) 降至 O(Nd2)O(N d^2)

其中无后效性是 KV Cache 能够成立的关键前提:因果 mask 确保了任何历史位置的 K/V 一旦算出就永远不会改变 — 无论未来生成什么新 token。这和 DP 中”已求解的子问题不受后续决策影响”的性质完全一致。

交互演示

下面的演示模拟了一个 5 步 Decode 过程。初始缓存包含 Prefill 阶段的 2 个 token 的 K/V。每一步中,观察:

  • 新 token 的 Q 向量如何与整个 K Cache 计算注意力分数
  • K Cache 和 V Cache 如何在每步增长一行(绿色高亮)
Decode 第 1 步 — 生成 t₃

当前 token t₃ 的 Query 向量与 KV Cache 中的所有 Key 做点积,然后将新的 K、V 追加到缓存。缓存从 2 行增长到 3 行。

新 token 的 Query 向量
q (t₃)
d₁
d₂
d₃
0.63
0.06
0.25
(1, 3)
K Cache(追加后)
d₁
d₂
d₃
t₁
0.70
0.55
0.40
t₂
0.92
0.89
-0.88
t₃
-0.45
0.08
0.55
(3, 3)
V Cache(追加后)
d₁
d₂
d₃
t₁
-0.97
-0.19
0.97
t₂
0.09
-0.49
0.58
t₃
0.09
0.83
-0.68
(3, 3)
注意力分数(q · K^T / √d_k → softmax)
缩放后分数
t₁
t₂
t₃
0.33
0.24
-0.08
注意力权重
t₁
t₂
t₃
0.39
0.35
0.26
输出 = 权重 × V Cache
output (t₃)
d₁
d₂
d₃
-0.32
-0.03
0.40
(1, 3)
绿色高亮行是本步新追加的 K/V 缓存行。无需重新计算之前 token 的 K、V — 它们已经在缓存中。KV Cache 大小: 23 行。

内存占用分析

KV Cache 是 LLM 推理中显存占用的主要来源之一,理解其内存公式至关重要。

单层单头的 KV Cache

对于一层 Transformer 的一个注意力头,缓存序列长度为 SS 时:

单头 KV Cache=2×S×dk×dtype_size\text{单头 KV Cache} = 2 \times S \times d_k \times \text{dtype\_size}

其中因子 2 对应 K 和 V 各一份。

完整模型的 KV Cache

对于完整模型,设 LL 为 Transformer 层数,nhn_h 为注意力头数,dkd_k 为每个头的维度(dk=dmodel/nhd_k = d_{\text{model}} / n_h),则:

KV Cache=2×L×nh×S×dk×dtype_size\text{KV Cache} = 2 \times L \times n_h \times S \times d_k \times \text{dtype\_size}

由于 nh×dk=dmodeln_h \times d_k = d_{\text{model}},可以简化为:

KV Cache=2×L×dmodel×S×dtype_size\boxed{\text{KV Cache} = 2 \times L \times d_{\text{model}} \times S \times \text{dtype\_size}}

实际数值

LLaMA-2 7B 为例(L=32L=32dmodel=4096d_{\text{model}}=4096,FP16):

序列长度 SSKV Cache 大小
5122×32×4096×512×2=2562 \times 32 \times 4096 \times 512 \times 2 = 256 MB
20482×32×4096×2048×2=12 \times 32 \times 4096 \times 2048 \times 2 = 1 GB
40962×32×4096×4096×2=22 \times 32 \times 4096 \times 4096 \times 2 = 2 GB

LLaMA-2 70B 为例(L=80L=80dmodel=8192d_{\text{model}}=8192,FP16):

序列长度 SSKV Cache 大小
20482×80×8192×2048×2=5.02 \times 80 \times 8192 \times 2048 \times 2 = 5.0 GiB
40962×80×8192×4096×2=10.02 \times 80 \times 8192 \times 4096 \times 2 = 10.0 GiB

注意:以上使用标准 MHA 公式计算(nh=64n_h = 64),即假设所有 head 独立缓存 KV。实际上 LLaMA-2 70B 使用 GQA(nkv=8n_{kv} = 8),真实 KV Cache 仅为上表的 1/81/8(S=4096 时约 1.25 GiB)。详见下方 GQA/MQA 小节

可以看到,对于大模型和长序列,KV Cache 的显存占用可以达到甚至超过模型参数本身。这也是为什么 KV Cache 管理和压缩是推理优化的核心方向。

单请求 KV Cache
1.000 GB
1 个并发总占用
1.00 GB
占 GPU 显存
1.3%
A100 80GB: 80 GB 总显存
公式: 2 × L(32) × kv_heads(32) × S(2,048) × d_k(128) × 2B

Batch 场景

当服务多个并发请求时,KV Cache 随 batch size BB 线性增长:

Total KV Cache=B×2×L×dmodel×S×dtype_size\text{Total KV Cache} = B \times 2 \times L \times d_{\text{model}} \times S \times \text{dtype\_size}

例如 LLaMA-2 7B,S=2048S=2048B=32B=32:KV Cache 总量 = 32×1 GB=3232 \times 1\text{ GB} = 32 GB。仅 KV Cache 就可能耗尽整块 GPU 的显存。

Cache 管理:PagedAttention 简介

传统 KV Cache 实现面临一个严重问题:显存碎片化

传统方式的问题

在最简单的实现中,系统为每个请求预分配最大序列长度的缓存空间。例如,最大长度 2048,即使实际只用了 100 个 token,也需要预留 2048 个位置的空间。这导致:

  • 内部碎片:预分配但未使用的空间被浪费
  • 外部碎片:请求结束后释放的空间难以被新请求完整利用
  • 显存利用率低:实际有效利用率可能只有 20-40%

PagedAttention

PagedAttention(来自 vLLM,论文 “Efficient Memory Management for Large Language Model Serving with PagedAttention”,arXiv 2309.06180)借鉴了操作系统的虚拟内存分页思想:

  1. 将 KV Cache 空间划分为固定大小的 Block(类似内存页,通常 16 或 32 个 token)
  2. 每个请求维护一个 Block Table(类似页表),记录逻辑块到物理块的映射
  3. 按需分配 Block:生成新 token 时,只有当前 Block 满了才分配新的 Block
  4. 请求结束后,Block 可以被回收并分配给新请求

核心优势:

  • 消除碎片:固定大小的 Block 分配和回收不会产生碎片
  • 动态增长:不需要预分配最大长度,按需增长
  • 显存利用率高:实际利用率可以接近 100%
  • 支持复杂调度:如 beam search 中多个候选可以共享前缀的 KV Cache Block
传统预分配
GPU 显存
PagedAttention
GPU 显存
Request A 分配空间
Request A Request B 内部碎片 外部碎片

Continuous Batching

与 PagedAttention 配合的另一项重要技术是 Continuous Batching(连续批处理):

  • 传统 Static Batching:一个 batch 中所有请求必须等最长的那个完成,才能处理新请求
  • Continuous Batching:某个请求完成后,立即插入新请求填补空位,不等待其他请求

这使得 GPU 始终保持高利用率,显著提升推理吞吐量。vLLM、TensorRT-LLM 等主流推理引擎都采用了 PagedAttention + Continuous Batching 的组合。

Continuous Batching 时间线3 个 GPU slot,请求动态进出,完成即释放Slot 0Slot 1Slot 201234567891011121314ABCDED 到达E 到达Prefill(深色)Decode(浅色)新请求到达请求 A/C 完成后 slot 立即被 D/E 填入 — 无空闲等待,GPU 利用率最大化

与 GQA/MQA 的关系:减少 KV Cache 的另一种方式

GQA 减少 KV CacheKV Cache Size = f(n_kv_heads)2 × n_kv_heads × d_head × seq_len × n_layers × batchMHA(32 KV heads)32/32 = 1×GQA(8 KV heads)8/32 = ¼MQA(1 KV heads)1/32 = 1/32减少 n_kv_heads → 直接缩减 KV Cache 大小

KV Cache 的内存瓶颈促使研究者从模型架构层面寻找解决方案。

回顾 KV Cache 公式

KV Cache=2×L×nh×dkdmodel×S×dtype_size\text{KV Cache} = 2 \times L \times \underbrace{n_h \times d_k}_{d_{\text{model}}} \times S \times \text{dtype\_size}

其中 nh×dkn_h \times d_k 对应所有注意力头的 KV 维度。如果能减少需要缓存的 KV 头数,就能直接缩小 KV Cache。

MQA (Multi-Query Attention)

MQA 让所有注意力头共享同一组 K 和 V,仅 Q 保持多头。KV Cache 缩小为:

KV CacheMQA=2×L×1×dk×S×dtype_size\text{KV Cache}_{\text{MQA}} = 2 \times L \times 1 \times d_k \times S \times \text{dtype\_size}

相比 MHA 缩小了 nhn_h 倍(例如 nh=32n_h=32 时缩小 32 倍)。

GQA (Grouped-Query Attention)

GQA 是 MHA 和 MQA 的折中:将 nhn_h 个 Q 头分成 gg 组,每组共享一组 K、V。KV Cache 为:

KV CacheGQA=2×L×g×dk×S×dtype_size\text{KV Cache}_{\text{GQA}} = 2 \times L \times g \times d_k \times S \times \text{dtype\_size}

相比 MHA 缩小了 nh/gn_h/g 倍。LLaMA-2 70B 使用 GQA(nh=64n_h=64g=8g=8),KV Cache 缩小为原来的 1/81/8

效果对比

方法KV 头数KV Cache 相对大小质量影响
MHAnhn_h1×1\times (基准)最优
GQAgg (1<g<nh1 < g < n_h)g/nhg/n_h极小损失
MQA111/nh1/n_h略有损失

GQA 在几乎不影响模型质量的前提下,大幅减少了 KV Cache,已被 LLaMA-2/3、Mistral 等主流模型采用。更详细的 GQA/MQA 介绍请参考 MQA 与 GQA 一文。

总结

概念说明
KV Cache缓存已计算的 K、V 向量,避免 Decode 时重复计算
无 Cache 的代价每步重算所有 KV,总计算量 O(N2d2)O(N^2 d^2)
有 Cache 的加速每步仅算 1 个新 token 的 KV 并追加,总量 O(Nd2)O(Nd^2)
内存公式2×L×dmodel×S×dtype_size2 \times L \times d_{\text{model}} \times S \times \text{dtype\_size}
PagedAttention借鉴虚拟内存分页,消除显存碎片,提升利用率
GQA/MQA从架构层面减少 KV 头数,直接缩小 Cache
核心权衡KV Cache 用显存计算,是经典的时空权衡

核心直觉:KV Cache 就像一个不断增长的”笔记本”。自回归生成时,每个新 token 只需要在笔记本末尾添加一行自己的 K 和 V,然后翻阅整个笔记本来决定关注什么。如果没有这个笔记本,每一步都要把之前所有 token 的 K 和 V 从头算一遍 — 这就像每次上课都要从第一节课重新复习到最新内容,显然是巨大的浪费。