本站内容由 AI 生成,可能存在错误。如发现问题,欢迎到 GitHub Issues 反馈。

Attention 变体:从 Sliding Window 到 MLA

Attention 变体:从 Sliding Window 到 MLA

更新于 2026-04-04

标准的 Multi-Head Attention 面临三大瓶颈:O(n2)O(n^2) 计算复杂度、KV cache 显存占用、长上下文处理。前面我们已经看过 GQA/MQA 如何通过共享 KV head 来减少缓存,这里进一步介绍其他重要的 attention 变体。

四个优化方向:

  • 稀疏化:减少每个 token 需要 attend 的范围(Sliding Window Attention)
  • 压缩 KV:减少每个位置存储的信息量(MLA)
  • 线性化:去掉 softmax,将 O(n2)O(n^2) 降到 O(n)O(n)(Linear Attention → GDN)
  • 混合架构:不同层用不同策略,取长补短(Hybrid Attention)

Sliding Window Attention

核心思想:每个 token 只 attend 前 ww 个 token,而非整个序列。

Attention(qi,K,V)=softmax(qiK[iw+1:i]Tdk)V[iw+1:i]\text{Attention}(q_i, K, V) = \text{softmax}\left(\frac{q_i \cdot K_{[i-w+1:i]}^T}{\sqrt{d_k}}\right) V_{[i-w+1:i]}

这将复杂度从 O(n2)O(n^2) 降到 O(nw)O(nw)。看起来损失了全局信息?其实不然 — 多层堆叠后,第 LL 层的有效感受野为 L×wL \times w。例如 Mistral 7B 有 32 层、窗口 w=4096w=4096,理论感受野覆盖 32×4096=13107232 \times 4096 = 131072 个 token。

Full Causal MaskSliding Window Mask (w=3)Key positionKey position计算量对比FullO(n²) = 36SWAO(nw) = 21 (42% less)

采用者: Mistral 7B (w=4096w=4096), Mixtral 8x7B (w=4096w=4096), Gemma 2(交替层使用)。

Sliding Window 还能和 Flash Attention 完美结合 — 窗口内的计算用 Flash Attention 的 tiling 策略高效完成,窗口外的直接跳过。

Hybrid Attention

不是所有层都需要 full attention — 将不同类型的 attention 混合使用,让每种类型发挥特长。

常见做法:

  • Gemma 2:偶数层 full attention + 奇数层 sliding window attention
  • Jamba (AI21):Attention 层 + Mamba (SSM) 层交替,1:3 的比例
  • Command-R (Cohere):部分层 full attention + 部分层 local attention
Hybrid Attention 层配置对比Gemma 2L0Full AttnL1SWAL2Full AttnL3SWAL4Full AttnL5SWAL6Full AttnL7SWAJambaL0AttentionL1MambaL2MambaL3AttentionL4MambaL5MambaL6AttentionL7MambaFull AttentionSliding WindowMamba (SSM)

设计选择的关键问题:full attention 层放哪里?比例怎么定?经验表明:靠近输出的层更需要全局信息(全局层放上面),靠近输入的层局部模式就够了。

Cross Attention

前面的 attention 变体都是 self-attention(Q/K/V 来自同一序列)。Cross attention 的核心区别是 Q 来自一个序列,K/V 来自另一个序列

Self-Attention
Self-Attention: Q, K, V 来自同一序列Input Sequence XQ = Wq·XK = Wk·XV = Wv·XAttention(Q, K, V)所有 Q, K, V 都从相同的输入 X 投影得到

典型场景

Encoder-Decoder 架构(翻译、摘要):

  • Decoder 的 hidden state 作为 Q
  • Encoder 的输出作为 K 和 V
  • 每个 decoder token “查询” encoder 的完整输入,决定关注输入的哪些部分
  • 采用者:T5, BART

多模态架构(图文理解):

  • 文本 decoder 的 token 作为 Q
  • Vision encoder 输出的图像 token 作为 K 和 V
  • 文本 token “查询” 图像 token,融合视觉信息
  • 采用者:Flamingo, LLaVA

和 self-attention 相比,cross attention 的 KV cache 行为不同:KV 来自 encoder 的固定输出,不会随着生成增长。

Multi-Latent Attention (MLA)

GQA 通过 共享 KV head 减少缓存,那能不能更激进?MLA 的思路是对 KV cache 做 低秩压缩,不存完整的 K 和 V,而是存一个低维的 compressed latent cKVc_{KV}

压缩过程:

cKV=WDKVh(h 是 hidden state, dmodeldcc_{KV} = W_{DKV} \cdot h \quad \text{(h 是 hidden state, } d_{model} \to d_c \text{)}

K=WUKcKV,V=WUVcKV(解压:dcdk,dvK = W_{UK} \cdot c_{KV}, \quad V = W_{UV} \cdot c_{KV} \quad \text{(解压:} d_c \to d_k, d_v \text{)}

只需缓存 cKVc_{KV}(维度远小于 K+V)。推理时再用 WUKW_{UK}WUVW_{UV} 解压。更妙的是,WUKW_{UK} 可以吸收进 WQW_Q 的矩阵乘法中,避免显式解压。

MLA: Multi-Latent Attention 数据流只缓存低维 c_KV(如 512 维),推理时解压为完整 K、VHidden State hd_modelCompress W_DKVd → d_cCache c_KVd_c (小!)W_UK → Kd_c → d_kW_UV → Vd_c → d_vAttention Output缓存点 — 大幅节省显存推理优化:W_UK 可吸收进 W_Q,避免显式解压 K — 进一步减少计算采用者:DeepSeek-V2, DeepSeek-V3, DeepSeek-R1

下面的计算器可以对比不同配置下 MHA、GQA、MLA 的 KV cache 大小:

KV Cache 大小对比 (FP16, seq_len=4096)head_dim = d_model / num_heads = 4096 / 32 = 128MHA: 2 × 32 heads × 128 dim × 4096 seq × 2B = 64.0 MBGQA: 2 × 8 kv_heads × 128 dim × 4096 seq × 2B = 16.0 MBMLA: 512 latent_dim × 4096 seq × 2B = 4.0 MBMHA64.0 MB (100%)GQA (8h)16.0 MB (25%)MLA4.0 MB (6%)

以 DeepSeek-V2 为例(dmodel=5120d_{model}=5120, 128 heads, dc=512d_c=512):MLA 的 KV cache 只有标准 MHA 的约 5%

采用者: DeepSeek-V2, DeepSeek-V3, DeepSeek-R1

Linear Attention 与 Gated Delta Net

前面的变体都保留了 softmax — 只是减少计算范围(SWA)或压缩缓存(MLA)。Linear Attention 走了更激进的路线:直接去掉 softmax,从根本上消除 O(n2)O(n^2) 复杂度

核心思想:去掉 softmax,改变计算顺序

标准 Attention 的计算是:

Attn(Q,K,V)=softmax(QKT/d)n×nV\text{Attn}(Q,K,V) = \underbrace{\text{softmax}(QK^T / \sqrt{d})}_{n \times n} \cdot V

必须先算 QKTQK^Tn×nn \times n 矩阵),再 softmax,再乘 VV。softmax 是逐行归一化操作,它阻止了矩阵乘法的结合律 — 你不能先算 KTVK^TV 再乘 QQ,因为 softmax 卡在中间。

Linear Attention (Katharopoulos et al., 2020) 的关键洞察:用特征映射 ϕ\phi 替代 softmax,使得:

Attn(Q,K,V)=ϕ(Q)(ϕ(K)TV)d×d\text{Attn}(Q,K,V) = \phi(Q) \cdot \underbrace{(\phi(K)^T \cdot V)}_{d \times d}

ϕ(K)TV\phi(K)^T Vd×dd \times d 矩阵(dd 是 head dimension,通常 64-128),与序列长度 nn 无关。当 ndn \gg d 时,这从 O(n2d)O(n^2 d) 降到了 O(nd2)O(nd^2) — 真正的线性复杂度。

标准 Attentionsoftmax( Q · KT/ √d ) · Vsoftmax 阻止结合律 — 必须先算 QKᵀ(n×n)512×512 = 262,144O(n²d)Linear Attentionφ(Q) · ( φ(K)T· V )去掉 softmax → 先算 Kᵀ V(d×d),再乘 Q64×64 = 4,096O(nd²)中间矩阵缩小 64×
序列长度 n512

为什么 φ 能解锁结合律

softmax 之所以阻止结合律,是因为它做了两件事:(1) exp() 保证非负,(2) 逐行归一化(除以行和)。归一化耦合了同一行的所有元素 — 计算 softmax(qiTk1)\text{softmax}(q_i^T k_1) 需要知道 qiTk2,qiTk3,...,qiTknq_i^T k_2, q_i^T k_3, ..., q_i^T k_n。你不能独立算一个元素,必须先算完整行的 QKTQK^Tn×nn \times n)。

ϕ\phi 的策略:用独立的逐元素变换替代耦合的归一化。将 softmax 核 sim(q,k)=exp(qTk)/Z\text{sim}(q, k) = \exp(q^T k) / Z 替换为核函数分解:

sim(q,k)=ϕ(q)Tϕ(k)\text{sim}(q, k) = \phi(q)^T \phi(k)

ϕ\phi 独立作用于每个 qqkk(不需要知道其他 key 的值),而内积天然满足结合律。展开第 ii 个输出(含归一化):

oi=ϕ(qi)Tjϕ(kj)vjTd×d, 与 i 无关ϕ(qi)Tjϕ(kj)d×1, 与 i 无关o_i = \frac{\phi(q_i)^T \overbrace{\sum_j \phi(k_j) v_j^T}^{d \times d,\ \text{与 } i \text{ 无关}}}{\phi(q_i)^T \underbrace{\sum_j \phi(k_j)}_{d \times 1,\ \text{与 } i \text{ 无关}}}

分子和分母中的求和都与查询位置 ii 无关,可以预先算好一次(O(nd)O(nd)),然后每个查询只做 O(d)O(d) 的向量乘法。

φ 的选择需要满足两个条件:(1) 输出非负 — 确保注意力权重非负,(2) 逐元素独立 — 不能像 softmax 那样耦合同行其他元素。常见选择包括 ϕ(x)=elu(x)+1\phi(x) = \text{elu}(x) + 1(Katharopoulos 2020 原始选择)和 ϕ(x)=ReLU(x)\phi(x) = \text{ReLU}(x)。理论上,Random Fourier Features 可以近似 softmax 的 exp 核,但计算较贵。

φ 效果不如 softmax 的根本原因:softmax 的 exp + 归一化天然产生尖锐、稀疏的注意力分布(大值指数放大,小值压到接近零),让模型能集中关注最相关的少数 token。简单的 φ 没有这种”赢者通吃”效果,所有 key 的贡献差异不大。这就是后续工作(RetNet/DeltaNet/GDN)加衰减、delta rule、门控的本质动机 — 弥补 φ 无法替代 softmax 选择性的缺陷。

RNN 形式:固定大小的状态

去掉 softmax 后,Linear Attention 可以写成 RNN 递推形式:

St=St1+ϕ(kt)ϕ(vt)T,ot=ϕ(qt)StS_t = S_{t-1} + \phi(k_t) \phi(v_t)^T, \quad o_t = \phi(q_t) \cdot S_t

StS_t 是一个 d×dd \times d状态矩阵,压缩了所有历史信息。推理时不需要 KV cache(随序列增长),只需维护这个固定大小的状态。

这个形式和 状态空间模型 (SSM) 惊人地相似:

Linear AttentionSSM / Mamba
状态更新St=St1+ktvtTS_t = S_{t-1} + k_t v_t^Tht=Aht1+Bxth_t = Ah_{t-1} + Bx_t
输出ot=qtSto_t = q_t S_tyt=Chty_t = Ch_t
状态大小d×dd \times d 矩阵(固定)NN 维向量(固定)
推理复杂度O(1)O(1) per tokenO(1)O(1) per token

两者本质相同:用固定大小的状态压缩历史,线性递推更新Mamba-2 的 SSD 框架 从数学上严格证明了这个等价性 — 结构化 SSM 就是一种带衰减的 linear attention。

代价:softmax 不是白删的

softmax 提供了稀疏、尖锐的注意力分布 — 它让模型能够集中关注少数关键 token(比如在长文本中精确定位某个名字)。去掉 softmax 后,注意力分布变得”平坦”,所有 token 的贡献差异减小。

这导致纯 Linear Attention 在需要精确检索的任务上(如 copying、in-context learning)远弱于标准 Attention。这和 SSM/Mamba 面临的限制本质相同 — 固定大小的状态无法精确记住任意长度序列中的特定信息。详见 Hybrid 架构:为什么纯 SSM 不够

从累加到纠错:状态更新的演进

基础 Linear Attention 的状态只做累加(St=St1+ktvtTS_t = S_{t-1} + k_t v_t^T),永远不遗忘。后续工作的核心改进就是让状态更新更智能:

2020基础 Linear Attention2023RetNet(指数衰减)2024DeltaNet(Delta Rule 纠错)2024GDN(门控 + Delta Rule)
状态更新
Sₜ = Sₜ₋₁ + kₜ vₜᵀ
输出: oₜ = qₜ Sₜ
关键变化: + kₜ vₜᵀ(直接累加)
核心创新
去掉 softmax,利用矩阵乘法结合律将 O(n²) 降到 O(n)。可写成 RNN 递推形式,推理时只需维护固定大小的 d×d 状态矩阵
局限 / 现状
状态只累加、不遗忘 — 所有历史信息混在一起,注意力分布"平坦",建模能力远弱于 softmax attention
论文: Katharopoulos et al.

演进脉络:

  1. Basic Linear Attention (2020):去掉 softmax,建立 RNN 形式。但状态只累加不遗忘,性能差距大
  2. RetNet (2023, MSR):加入指数衰减因子 γ\gamma,旧信息每步自动淡化。但 γ\gamma 是固定超参数
  3. DeltaNet (2024, Yang et al.):引入 delta rule — 不盲目累加,而是先查状态中”已有什么”(ktTSt1k_t^T S_{t-1}),只写入差值(“还缺什么”)。这是一种联想记忆的在线学习规则
  4. Gated Delta Net (GDN) (2024, ICLR 2025):结合两种互补机制 — αt\alpha_t 门控(来自 Mamba2)实现选择性遗忘 + delta rule 实现精准写入。论文证明 GDN 在语言建模、上下文检索、长序列理解上超越 Mamba2

与 Mamba 的关系

Linear Attention 和 SSM/Mamba 是同一思想的两种表述 — 前者从 Attention 出发(去掉 softmax),后者从控制论出发(状态空间递推),最终殊途同归。

GDN 直接体现了这种融合:它在论文标题中就写明 “Improving Mamba2 with Delta Rule” — 把 Mamba2 的门控机制和 linear attention 的 delta rule 组合在一起。GDN 的 Hybrid 架构(GDN 层 + sliding window attention 层交替)和 Jamba(Mamba 层 + attention 层交替)的思路完全一致。

Linear Attention 家族和 SSM/Mamba 家族正在快速融合。理解一方就理解了另一方的核心思想。详细的 SSM/Mamba 原理参见 状态空间模型与 Mamba,混合架构设计参见 Hybrid 架构

对比总结

方法计算复杂度KV Cache核心思想
Full MHAO(n²d)2 × n_heads × d_head × seq每个 head 独立 Q/K/V,完整注意力
GQAO(n²d)2 × n_kv_heads × d_head × seq多个 Q head 共享 KV head,减少 KV 缓存
Sliding WindowO(nwd)2 × n_heads × d_head × w每个 token 只 attend 前 w 个,堆叠扩大感受野
Cross AttentionO(n·m·d)2 × n_heads × d_head × m (encoder)Q 来自 decoder,KV 来自 encoder/视觉
MLAO(n²d)latent_dim × seq (极小)低秩压缩 KV cache,存 compressed latent
Hybrid混合分层不同混合不同 attention 类型(full + SWA / SSM)

选型指南:

  • 长上下文 + 低延迟 → Sliding Window(Mistral 方案)
  • 极致 KV cache 压缩 → MLA(DeepSeek 方案)
  • 跨模态 / Encoder-Decoder → Cross Attention
  • 线性复杂度 + 固定状态 → Linear Attention / GDN(替代 KV cache)
  • 平衡方案 → Hybrid(Gemma 2 方案,或 GDN + SWA 交替)
  • 通用 KV 节省 → GQA(目前最主流的折中选择)