本站内容由 AI 生成,可能存在错误。如发现问题,欢迎到 GitHub Issues 反馈。

Multi-Head Attention

Multi-Head Attention

更新于 2026-04-23

简介:为什么单头不够

在上一篇文章中,我们详细拆解了 Scaled Dot-Product Attention 的计算过程。但在实际的 Transformer 中,注意力并不是只计算一次 — 而是并行计算多次,这就是 Multi-Head Attention

为什么需要多个头?考虑这个句子:

“The animal didn’t cross the street because it was too tired.”

对于 “it” 这个 token,我们需要同时关注多种不同的关系:

  • 指代关系:it → animal(语义共指)
  • 句法关系:it → was(主谓搭配)
  • 因果关系:it → because(逻辑连接)

如果只有一个注意力头,它的 softmax 输出是一个概率分布 — 只能产生一种权重模式。这意味着模型必须将所有不同类型的关系”混合”到同一组权重中,严重限制了表达能力。

Multi-Head Attention 的核心思想:让不同的注意力头在不同的子空间中独立计算注意力,每个头可以关注不同类型的模式,最后将结果拼接起来。

单头 vs 多头注意力对比Single Head只有一种注意力模式Multi-Head (h=8)8 种不同的注意力模式h0h1h2h3h4h5h6h7vs单头 = 一个视角;多头 = 多元视角

多头的数学表达

完整公式

MultiHead(Q,K,V)=Concat(head1,,headh)WO\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \, W^O

其中每个头的计算为:

headi=Attention(QWiQ,KWiK,VWiV)\text{head}_i = \text{Attention}(Q W_i^Q, \, K W_i^K, \, V W_i^V)

各参数矩阵的形状:

  • WiQRH×dkW_i^Q \in \mathbb{R}^{H \times d_k} — 第 ii 个头的 Query 投影
  • WiKRH×dkW_i^K \in \mathbb{R}^{H \times d_k} — 第 ii 个头的 Key 投影
  • WiVRH×dvW_i^V \in \mathbb{R}^{H \times d_v} — 第 ii 个头的 Value 投影
  • WORhdv×HW^O \in \mathbb{R}^{h \cdot d_v \times H} — 输出投影矩阵

其中 HH 是模型的隐藏维度,hh 是头的数量,dk=dv=H/hd_k = d_v = H / h

与单头 Attention 的关系

单头 Attention 在完整的 HH 维空间中计算:

SingleHead(Q,K,V)=Attention(Q,K,V)\text{SingleHead}(Q, K, V) = \text{Attention}(Q, K, V)

多头 Attention 则将 HH 维空间切分为 hhdkd_k 维子空间,在每个子空间中独立计算注意力,最后合并:

H=h×dkH = h \times d_k

关键点:多头 Attention 的总参数量和计算量与单头(使用完整维度)几乎相同。我们没有增加成本,而是将计算分配到了多个并行的子空间中。

空间切分的直觉:不同 head 关注不同模式

为什么在子空间中计算注意力比在完整空间中更好?直觉是每个 head 可以学习一种独立的”关注模式”。

研究者在分析训练好的 Transformer 时发现,不同 head 确实学到了不同的功能(参见原论文附录和后续研究):

Head 类型关注模式示例
位置 Head总是关注相邻 tokenHead 3 关注前一个 token
语法 Head关注句法依存关系Head 7 关注动词的主语
语义 Head关注语义相关的 tokenHead 5 关注共指关系
分隔 Head关注特殊 token(如 [SEP])Head 11 关注句子分隔符

每个 head 在自己的低维子空间中,通过独立的 WiQW_i^QWiKW_i^KWiVW_i^V 学习不同的投影方式,从而在不同的”视角”下计算注意力。

不同 Head 的注意力模式特化不同 Head 自动学习不同模式Head 0: 位置相邻Head 1: 语法依赖Head 2: 语义关联Head 3: 分隔符Head 4: 距离递减Head 5: 段落边界Head 6: 稀疏全局Head 7: 均匀分布
单头 Attention — 所有模式混在一起
TheThecatcatsatsatononthethematmatbecausebecauseititwaswastiredtired
多头 Attention (h=4) — 每个 head 关注不同模式
Head 1: 局部模式
TheThecatcatsatsatononthethematmatbecausebecauseititwaswastiredtired
Head 2: 动词-主语
TheThecatcatsatsatononthethematmatbecausebecauseititwaswastiredtired
Head 3: 代词指代
TheThecatcatsatsatononthethematmatbecausebecauseititwaswastiredtired
Head 4: 介词短语
TheThecatcatsatsatononthethematmatbecausebecauseititwaswastiredtired

示意图,非真实模型权重 — 展示多头如何让不同 head 专注于不同关系模式

维度分析:reshape 和 transpose 的详细追踪

在实际实现中,我们不会为每个 head 单独做矩阵乘法 — 这太慢了。取而代之的是,用一次大的投影 + reshape + transpose,高效地并行计算所有 head。

详细的 tensor shape 变化

以 Query 为例(Key 和 Value 同理),设 batch size 为 BB,序列长度为 SS,隐藏维度为 HH,头数为 hh,每头维度为 dk=H/hd_k = H / h

第一步:线性投影

XRB×S×HWQQRB×S×HX \in \mathbb{R}^{B \times S \times H} \xrightarrow{W^Q} Q \in \mathbb{R}^{B \times S \times H}

这里的 WQRH×HW^Q \in \mathbb{R}^{H \times H} 是一个统一的大投影矩阵,等价于 hh 个小投影矩阵 WiQW_i^Q 的拼接。

第二步:reshape — 切分 head

QRB×S×HreshapeQRB×S×h×dkQ \in \mathbb{R}^{B \times S \times H} \xrightarrow{\text{reshape}} Q \in \mathbb{R}^{B \times S \times h \times d_k}

将最后一个维度从 HH 切分为 h×dkh \times d_k

第三步:transpose — 让 head 维度提前

QRB×S×h×dktransposeQRB×h×S×dkQ \in \mathbb{R}^{B \times S \times h \times d_k} \xrightarrow{\text{transpose}} Q \in \mathbb{R}^{B \times h \times S \times d_k}

交换 SShh 维度。这样每个 head 对应一个 (S,dk)(S, d_k) 的矩阵,可以对 B×hB \times h 个”batch”并行计算 attention。

第四步:Scaled Dot-Product Attention

QRB×h×S×dk,KRB×h×S×dk,VRB×h×S×dkQ \in \mathbb{R}^{B \times h \times S \times d_k}, \quad K \in \mathbb{R}^{B \times h \times S \times d_k}, \quad V \in \mathbb{R}^{B \times h \times S \times d_k} Scores=QKTdkRB×h×S×S\text{Scores} = \frac{QK^T}{\sqrt{d_k}} \in \mathbb{R}^{B \times h \times S \times S} Output=softmax(Scores)VRB×h×S×dk\text{Output} = \text{softmax}(\text{Scores}) \cdot V \in \mathbb{R}^{B \times h \times S \times d_k}

第五步:transpose 回来 + reshape

RB×h×S×dktransposeRB×S×h×dkreshapeRB×S×H\mathbb{R}^{B \times h \times S \times d_k} \xrightarrow{\text{transpose}} \mathbb{R}^{B \times S \times h \times d_k} \xrightarrow{\text{reshape}} \mathbb{R}^{B \times S \times H}

将所有 head 的输出重新拼接成 HH 维向量。

第六步:输出投影

RB×S×HWORB×S×H\mathbb{R}^{B \times S \times H} \xrightarrow{W^O} \mathbb{R}^{B \times S \times H}

PyTorch 伪代码

# 线性投影 (一次性计算所有 head)
Q = self.W_q(X)                      # (B, S, H)
K = self.W_k(X)                      # (B, S, H)
V = self.W_v(X)                      # (B, S, H)

# reshape + transpose
Q = Q.view(B, S, h, d_k).transpose(1, 2)  # (B, h, S, d_k)
K = K.view(B, S, h, d_k).transpose(1, 2)  # (B, h, S, d_k)
V = V.view(B, S, h, d_k).transpose(1, 2)  # (B, h, S, d_k)

# Scaled Dot-Product Attention (对所有 head 并行)
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (B, h, S, S)
weights = F.softmax(scores, dim=-1)                  # (B, h, S, S)
output = weights @ V                                  # (B, h, S, d_k)

# transpose + reshape + 输出投影
output = output.transpose(1, 2).contiguous().view(B, S, H)  # (B, S, H)
output = self.W_o(output)                                     # (B, S, H)

多头并行计算结构图

Multi-Head Attention 计算结构输入 X: (B, S, H)Linear W_Q: (H, H)Linear W_K: (H, H)Linear W_V: (H, H)reshape + transpose → (B,h,S,d_k)reshape + transpose → (B,h,S,d_k)reshape + transpose → (B,h,S,d_k)h 个 Head 并行计算Head 1Attention(Q₁,K₁,V₁)(B, 1, S, d_k)Head 2Attention(Q₂,K₂,V₂)(B, 1, S, d_k)...Head hAttention(Q_h,K_h,V_h)(B, 1, S, d_k)每个 Head 内部:1. QK^T / √d_k2. + Mask3. Softmax4. × VConcat → reshape: (B, S, H)Linear W_O: (H, H)输出: (B, S, H)

输出投影:WOW^O 的作用

输出投影 W^O 的作用各头输出 (d_k=64)h0h1h2h3h4h5h6h7Concat拼接 (512 维)×W^O[512, 512]输出d=512Concat 只是拼接,W^O 才真正混合各头视角

多头拼接后,我们得到的向量已经回到 HH 维。那为什么还需要一个输出投影 WOW^O

作用一:融合多头信息

各 head 在自己的子空间中独立计算,彼此没有交互。WOW^O 提供了一个跨 head 的线性变换,让不同 head 捕获的信息能够混合和交互。

可以这样理解:

  • 每个 head 是一个”专家”,捕获某种特定的关注模式
  • WOW^O 是一个”融合层”,将各专家的意见综合为最终决策

作用二:维持残差连接的兼容性

Transformer 中每个子层的输出需要与输入做残差连接:Output=SubLayer(X)+X\text{Output} = \text{SubLayer}(X) + XWOW^O 确保多头注意力的输出与输入具有相同的维度和数值范围,使残差连接能够正常工作。

各 Head 独立输出

4 个 head 各自计算 Attention 后得到 (d_k=3) 维的输出向量,每个 head 用不同颜色标识。

Head 0 output (d_k=3)
h0d0
h0d1
h0d2
Head 1 output (d_k=3)
h1d0
h1d1
h1d2
Head 2 output (d_k=3)
h2d0
h2d1
h2d2
Head 3 output (d_k=3)
h3d0
h3d1
h3d2

参数量分析

对于一个多头注意力层,投影矩阵的参数量为:

3×H2WQ,WK,WV+H2WO=4H2\underbrace{3 \times H^2}_{W^Q, W^K, W^V} + \underbrace{H^2}_{W^O} = 4H^2

注意 hhWiQW_i^Q(每个 H×dkH \times d_k)拼接后等价于一个 H×HH \times H 的大矩阵 — 所以无论 head 数如何变化,总参数量不变

典型配置:不同模型的 head 设计

模型隐藏维度 HH头数 hh每头维度 dkd_k层数
Transformer (原论文)5128646
GPT-2 (Small)768126412
GPT-2 (Medium)1024166424
GPT-3 (175B)122889612896
LLaMA-7B40963212832
LLaMA-65B81926412880

有趣的观察

  • 小模型通常使用 dk=64d_k = 64(原论文的选择)
  • 大模型倾向于使用 dk=128d_k = 128(更大的子空间容量)
  • 头数随模型规模增长,但 dkd_k 通常保持不变
  • 这暗示”子空间的粒度”有一个合理的范围,更多的 head 意味着更多并行的关注模式

变体:Multi-Query Attention 和 Grouped-Query Attention

近年来出现了一些效率优化变体:

  • Multi-Query Attention (MQA):所有 head 共享同一组 K 和 V,只有 Q 不同。大幅减少 KV cache 的内存占用。
  • Grouped-Query Attention (GQA):折中方案,将 hh 个 head 分为 gg 组,每组共享 K 和 V。LLaMA-2 70B 使用 8 组 GQA。

这些变体的核心权衡是:用少量精度损失换取显著的推理效率提升。

总结

Multi-Head Attention 的设计哲学是分而治之

概念说明
为什么多头单头只能学一种注意力模式,多头可以并行关注不同关系
空间切分H=h×dkH = h \times d_k,每个 head 在 dkd_k 维子空间独立计算
实现技巧reshape + transpose 让所有 head 并行计算,无额外开销
维度流转(B,S,H)(B,h,S,dk)Attention(B,S,H)(B,S,H) \to (B,h,S,d_k) \to \text{Attention} \to (B,S,H)
输出投影WOW^O 融合各 head 信息,保持维度兼容
参数量4H24H^2,与 head 数无关

核心直觉:Multi-Head Attention 不是简单地”多做几次 Attention”。它的精妙之处在于,通过将高维空间切分为多个低维子空间,在不增加计算量的前提下,让模型能够同时从多个角度理解 token 之间的关系。这就像用多台摄像机从不同角度拍摄同一个场景 — 每个角度都捕获了独特的信息,合在一起才是完整的画面。