简介:为什么单头不够
在上一篇文章中,我们详细拆解了 Scaled Dot-Product Attention 的计算过程。但在实际的 Transformer 中,注意力并不是只计算一次 — 而是并行计算多次,这就是 Multi-Head Attention。
为什么需要多个头?考虑这个句子:
“The animal didn’t cross the street because it was too tired.”
对于 “it” 这个 token,我们需要同时关注多种不同的关系:
- 指代关系:it → animal(语义共指)
- 句法关系:it → was(主谓搭配)
- 因果关系:it → because(逻辑连接)
如果只有一个注意力头,它的 softmax 输出是一个概率分布 — 只能产生一种权重模式。这意味着模型必须将所有不同类型的关系”混合”到同一组权重中,严重限制了表达能力。
Multi-Head Attention 的核心思想:让不同的注意力头在不同的子空间中独立计算注意力,每个头可以关注不同类型的模式,最后将结果拼接起来。
多头的数学表达
完整公式
MultiHead(Q,K,V)=Concat(head1,…,headh)WO
其中每个头的计算为:
headi=Attention(QWiQ,KWiK,VWiV)
各参数矩阵的形状:
- WiQ∈RH×dk — 第 i 个头的 Query 投影
- WiK∈RH×dk — 第 i 个头的 Key 投影
- WiV∈RH×dv — 第 i 个头的 Value 投影
- WO∈Rh⋅dv×H — 输出投影矩阵
其中 H 是模型的隐藏维度,h 是头的数量,dk=dv=H/h。
与单头 Attention 的关系
单头 Attention 在完整的 H 维空间中计算:
SingleHead(Q,K,V)=Attention(Q,K,V)
多头 Attention 则将 H 维空间切分为 h 个 dk 维子空间,在每个子空间中独立计算注意力,最后合并:
H=h×dk
关键点:多头 Attention 的总参数量和计算量与单头(使用完整维度)几乎相同。我们没有增加成本,而是将计算分配到了多个并行的子空间中。
空间切分的直觉:不同 head 关注不同模式
为什么在子空间中计算注意力比在完整空间中更好?直觉是每个 head 可以学习一种独立的”关注模式”。
研究者在分析训练好的 Transformer 时发现,不同 head 确实学到了不同的功能(参见原论文附录和后续研究):
| Head 类型 | 关注模式 | 示例 |
|---|
| 位置 Head | 总是关注相邻 token | Head 3 关注前一个 token |
| 语法 Head | 关注句法依存关系 | Head 7 关注动词的主语 |
| 语义 Head | 关注语义相关的 token | Head 5 关注共指关系 |
| 分隔 Head | 关注特殊 token(如 [SEP]) | Head 11 关注句子分隔符 |
每个 head 在自己的低维子空间中,通过独立的 WiQ、WiK、WiV 学习不同的投影方式,从而在不同的”视角”下计算注意力。
单头 Attention — 所有模式混在一起
多头 Attention (h=4) — 每个 head 关注不同模式
Head 1: 局部模式
Head 2: 动词-主语
Head 3: 代词指代
Head 4: 介词短语
示意图,非真实模型权重 — 展示多头如何让不同 head 专注于不同关系模式
维度分析:reshape 和 transpose 的详细追踪
在实际实现中,我们不会为每个 head 单独做矩阵乘法 — 这太慢了。取而代之的是,用一次大的投影 + reshape + transpose,高效地并行计算所有 head。
详细的 tensor shape 变化
以 Query 为例(Key 和 Value 同理),设 batch size 为 B,序列长度为 S,隐藏维度为 H,头数为 h,每头维度为 dk=H/h:
第一步:线性投影
X∈RB×S×HWQQ∈RB×S×H
这里的 WQ∈RH×H 是一个统一的大投影矩阵,等价于 h 个小投影矩阵 WiQ 的拼接。
第二步:reshape — 切分 head
Q∈RB×S×HreshapeQ∈RB×S×h×dk
将最后一个维度从 H 切分为 h×dk。
第三步:transpose — 让 head 维度提前
Q∈RB×S×h×dktransposeQ∈RB×h×S×dk
交换 S 和 h 维度。这样每个 head 对应一个 (S,dk) 的矩阵,可以对 B×h 个”batch”并行计算 attention。
第四步:Scaled Dot-Product Attention
Q∈RB×h×S×dk,K∈RB×h×S×dk,V∈RB×h×S×dk
Scores=dkQKT∈RB×h×S×S
Output=softmax(Scores)⋅V∈RB×h×S×dk
第五步:transpose 回来 + reshape
RB×h×S×dktransposeRB×S×h×dkreshapeRB×S×H
将所有 head 的输出重新拼接成 H 维向量。
第六步:输出投影
RB×S×HWORB×S×H
PyTorch 伪代码
# 线性投影 (一次性计算所有 head)
Q = self.W_q(X) # (B, S, H)
K = self.W_k(X) # (B, S, H)
V = self.W_v(X) # (B, S, H)
# reshape + transpose
Q = Q.view(B, S, h, d_k).transpose(1, 2) # (B, h, S, d_k)
K = K.view(B, S, h, d_k).transpose(1, 2) # (B, h, S, d_k)
V = V.view(B, S, h, d_k).transpose(1, 2) # (B, h, S, d_k)
# Scaled Dot-Product Attention (对所有 head 并行)
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k) # (B, h, S, S)
weights = F.softmax(scores, dim=-1) # (B, h, S, S)
output = weights @ V # (B, h, S, d_k)
# transpose + reshape + 输出投影
output = output.transpose(1, 2).contiguous().view(B, S, H) # (B, S, H)
output = self.W_o(output) # (B, S, H)
多头并行计算结构图
输出投影:WO 的作用
多头拼接后,我们得到的向量已经回到 H 维。那为什么还需要一个输出投影 WO?
作用一:融合多头信息
各 head 在自己的子空间中独立计算,彼此没有交互。WO 提供了一个跨 head 的线性变换,让不同 head 捕获的信息能够混合和交互。
可以这样理解:
- 每个 head 是一个”专家”,捕获某种特定的关注模式
- WO 是一个”融合层”,将各专家的意见综合为最终决策
作用二:维持残差连接的兼容性
Transformer 中每个子层的输出需要与输入做残差连接:Output=SubLayer(X)+X。WO 确保多头注意力的输出与输入具有相同的维度和数值范围,使残差连接能够正常工作。
各 Head 独立输出
4 个 head 各自计算 Attention 后得到 (d_k=3) 维的输出向量,每个 head 用不同颜色标识。
参数量分析
对于一个多头注意力层,投影矩阵的参数量为:
WQ,WK,WV3×H2+WOH2=4H2
注意 h 个 WiQ(每个 H×dk)拼接后等价于一个 H×H 的大矩阵 — 所以无论 head 数如何变化,总参数量不变。
典型配置:不同模型的 head 设计
| 模型 | 隐藏维度 H | 头数 h | 每头维度 dk | 层数 |
|---|
| Transformer (原论文) | 512 | 8 | 64 | 6 |
| GPT-2 (Small) | 768 | 12 | 64 | 12 |
| GPT-2 (Medium) | 1024 | 16 | 64 | 24 |
| GPT-3 (175B) | 12288 | 96 | 128 | 96 |
| LLaMA-7B | 4096 | 32 | 128 | 32 |
| LLaMA-65B | 8192 | 64 | 128 | 80 |
有趣的观察:
- 小模型通常使用 dk=64(原论文的选择)
- 大模型倾向于使用 dk=128(更大的子空间容量)
- 头数随模型规模增长,但 dk 通常保持不变
- 这暗示”子空间的粒度”有一个合理的范围,更多的 head 意味着更多并行的关注模式
变体:Multi-Query Attention 和 Grouped-Query Attention
近年来出现了一些效率优化变体:
- Multi-Query Attention (MQA):所有 head 共享同一组 K 和 V,只有 Q 不同。大幅减少 KV cache 的内存占用。
- Grouped-Query Attention (GQA):折中方案,将 h 个 head 分为 g 组,每组共享 K 和 V。LLaMA-2 70B 使用 8 组 GQA。
这些变体的核心权衡是:用少量精度损失换取显著的推理效率提升。
总结
Multi-Head Attention 的设计哲学是分而治之:
| 概念 | 说明 |
|---|
| 为什么多头 | 单头只能学一种注意力模式,多头可以并行关注不同关系 |
| 空间切分 | H=h×dk,每个 head 在 dk 维子空间独立计算 |
| 实现技巧 | reshape + transpose 让所有 head 并行计算,无额外开销 |
| 维度流转 | (B,S,H)→(B,h,S,dk)→Attention→(B,S,H) |
| 输出投影 | WO 融合各 head 信息,保持维度兼容 |
| 参数量 | 4H2,与 head 数无关 |
核心直觉:Multi-Head Attention 不是简单地”多做几次 Attention”。它的精妙之处在于,通过将高维空间切分为多个低维子空间,在不增加计算量的前提下,让模型能够同时从多个角度理解 token 之间的关系。这就像用多台摄像机从不同角度拍摄同一个场景 — 每个角度都捕获了独特的信息,合在一起才是完整的画面。