本站内容由 AI 生成,可能存在错误。如发现问题,欢迎到 GitHub Issues 反馈。

Attention 计算详解

Attention 计算详解

更新于 2026-04-23

简介:Attention 是 Transformer 的核心

在上一篇文章中,我们了解了 Q、K、V 三个矩阵是如何通过线性投影得到的。本文将深入拆解 Attention 的计算过程 — 从 Q、K、V 出发,一步步推导出最终输出。

Attention 机制的本质是一种可微的软检索:用 Query 去匹配所有 Key,根据匹配程度对 Value 做加权平均。每个 token 的输出不再是固定的,而是根据上下文动态聚合的。

完整公式:Scaled Dot-Product Attention

Transformer 采用的 Attention 形式称为 Scaled Dot-Product Attention

Attention(Q,K,V)=softmax ⁣(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V

其中:

  • QRS×dkQ \in \mathbb{R}^{S \times d_k} — Query 矩阵
  • KRS×dkK \in \mathbb{R}^{S \times d_k} — Key 矩阵
  • VRS×dvV \in \mathbb{R}^{S \times d_v} — Value 矩阵(通常 dv=dkd_v = d_k
  • dkd_k — 每个注意力头的维度
  • dk\sqrt{d_k} — 缩放因子,防止点积过大

这个公式看似简洁,但包含了五个关键步骤。下面逐一拆解。

Input X(S, H)Q = X·Wq(S, H)K = X·Wk(S, H)V = X·Wv(S, H)reshape(S, d_k)Q·Kᵀ(S, S)÷ √d_k(S, S)+ mask(S, S)softmax(S, S)× V(S, d_v)concat(S, H)× Wo(S, H)
Input X
(S, H)

输入序列的隐藏表示

简化模式省略了 batch (B) 和 多头 (h) 维度

Attention 分步计算Scaled Dot-Product Attention 四步分解1Q·K^T原始分数[S, S]2÷√d_k缩放[S, S]3+Mask因果遮罩[S, S]4softmax×V加权输出[S, d_v]每步的矩阵形状:Q·K^T → [S,S] → 缩放 → 遮罩 → softmax → ×V → [S,d_v]高注意力遮蔽区域

分步拆解:每一步的数学意义

第一步:QKTQK^T — 计算原始注意力分数

Scores=QKTRS×S\text{Scores} = QK^T \in \mathbb{R}^{S \times S}

这是一个矩阵乘法:QQ 的形状为 (S,dk)(S, d_k)KTK^T 的形状为 (dk,S)(d_k, S),结果为 (S,S)(S, S)

直觉: 结果矩阵的第 ii 行第 jj 列是 Query 向量 qiq_i 与 Key 向量 kjk_j 的点积:

Scoresij=qikj=l=1dkqilkjl\text{Scores}_{ij} = q_i \cdot k_j = \sum_{l=1}^{d_k} q_{il} \cdot k_{jl}

点积衡量两个向量的”相似度”:值越大,表示 token ii 对 token jj 的关注度越高。

第二步:除以 dk\sqrt{d_k} — 缩放

Scaled=QKTdk\text{Scaled} = \frac{QK^T}{\sqrt{d_k}}

为什么需要缩放?这不是一个随意的设计,而是基于严格的统计分析。详见后文”Scaling 的必要性”一节。

第三步:Mask — 遮罩(可选)

Maskedij={Scaledijif jiif j>i\text{Masked}_{ij} = \begin{cases} \text{Scaled}_{ij} & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}

在 Decoder 的自注意力中,token ii 不能看到位置 ii 之后的 token(因为那些 token 在自回归生成时还不存在)。通过将上三角设为 -\infty,softmax 后对应权重变为 0。详见后文”Causal Mask”一节。

第四步:Softmax — 行归一化

Weightsi=softmax(Maskedi)=eMaskedijjeMaskedij\text{Weights}_i = \text{softmax}(\text{Masked}_i) = \frac{e^{\text{Masked}_{ij}}}{\sum_{j} e^{\text{Masked}_{ij}}}

对分数矩阵的每一行独立做 softmax,将原始分数转化为概率分布(非负且和为 1)。

第五步:乘以 VV — 加权求和

Output=WeightsVRS×dv\text{Output} = \text{Weights} \cdot V \in \mathbb{R}^{S \times d_v}

权重矩阵 (S,S)(S, S) 乘以 Value 矩阵 (S,dv)(S, d_v),得到最终输出 (S,dv)(S, d_v)。每个 token 的输出是所有 Value 向量的加权平均:

Outputi=j=1SWeightsijvj\text{Output}_i = \sum_{j=1}^{S} \text{Weights}_{ij} \cdot v_j

交互动画:Attention 计算全过程

下面用一个小例子(S=4S=4, dk=3d_k=3)演示上述五个步骤的完整计算过程。点击”下一步”逐步查看。

Q 和 K 矩阵

从上一步的线性投影中,我们已经得到了 Q 和 K 矩阵,形状都是 (S=4, d_k=3)。接下来要计算它们之间的注意力分数。

Q ∈ ℝ^(4×3)
d₁
d₂
d₃
t₁
0.11
0.77
0.08
t₂
-0.55
-0.28
0.24
t₃
-0.79
0.97
-0.74
t₄
-0.72
0.95
0.36
(4, 3)
K ∈ ℝ^(4×3)
d₁
d₂
d₃
t₁
-0.89
-0.34
-0.17
t₂
-0.17
0.81
-0.10
t₃
0.86
-0.99
0.30
t₄
0.10
-0.91
0.92
(4, 3)

Scaling 的必要性:为什么除以 dk\sqrt{d_k}

这是面试和学习中最常被问到的问题之一。原论文(Vaswani et al., 2017)给出了明确的解释:

统计分析

假设 qqkk 的每个分量都是独立的随机变量,均值为 0,方差为 1。那么它们的点积:

qk=l=1dkqlklq \cdot k = \sum_{l=1}^{d_k} q_l \cdot k_l

的统计性质为:

E[qk]=0,Var(qk)=dk\mathbb{E}[q \cdot k] = 0, \quad \text{Var}(q \cdot k) = d_k

方差推导:每个 qlklq_l \cdot k_l 的方差为 Var(ql)Var(kl)=1\text{Var}(q_l) \cdot \text{Var}(k_l) = 1(因为均值为 0 的随机变量之积的方差等于各自方差之积),dkd_k 个独立项求和后方差为 dkd_k

问题

dkd_k 较大时(例如 GPT-3 使用 dk=128d_k = 128),点积的量级约为 12811.3\sqrt{128} \approx 11.3。这意味着 softmax 的输入值会非常大,导致:

  1. Softmax 输出接近 one-hotsoftmax([10,1,1])[0.9999,0.0001,0.0001]\text{softmax}([10, 1, 1]) \approx [0.9999, 0.0001, 0.0001]
  2. 梯度几乎消失:在 softmax 的饱和区,梯度趋近于 0,模型无法有效学习

解决方案

除以 dk\sqrt{d_k} 后,点积的方差恢复为 1:

Var ⁣(qkdk)=dkdk=1\text{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1

这样 softmax 的输入保持在合理范围内,梯度流畅,训练稳定。

未缩放: QKT
k1k2k3k4k5k6k7k8
方差: 5.14 · Softmax 熵: 2.056 bits
→ Softmax 输出
k1k2k3k4k5k6k7k8
缩放后: QKT / √dk
k1k2k3k4k5k6k7k8
方差: 0.08 · Softmax 熵: 2.947 bits
→ Softmax 输出
k1k2k3k4k5k6k7k8
观察:d越大 → 未缩放分数的方差越大 → Softmax 输出越接近 one-hot(熵趋近 0)。除以 √d后方差恢复到 ~1,Softmax 输出保持均匀分布(熵接近 3.0 bits)。

原论文原话:“We suspect that for large values of dkd_k, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients.”

Causal Mask:Decoder 的因果遮罩

为什么需要遮罩

在自回归(autoregressive)语言模型中,生成 token ii 时只能看到 token 1,2,,i1, 2, \ldots, i,不能看到 i+1,i+2,i+1, i+2, \ldots(因为它们还没被生成)。

训练时为了并行化,我们会一次性输入整个序列,但需要通过遮罩模拟”看不到未来”的效果。

遮罩矩阵

Mij={0if jiif j>iM_{ij} = \begin{cases} 0 & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}

将遮罩加到缩放后的分数上:Masked=Scaled+M\text{Masked} = \text{Scaled} + M

因为 e=0e^{-\infty} = 0,所以 softmax 后被遮罩位置的权重为 0。

遮罩的形状

对于序列长度 S=4S = 4

M=(0000000000)M = \begin{pmatrix} 0 & -\infty & -\infty & -\infty \\ 0 & 0 & -\infty & -\infty \\ 0 & 0 & 0 & -\infty \\ 0 & 0 & 0 & 0 \end{pmatrix}

这是一个下三角矩阵。第 ii 行只保留前 ii 个位置的分数。

原始 QKᵀ 分数矩阵

原始点积分数 — 每个格子表示 token i 对 token j 的原始相关性。

Scores = QKᵀ/√d_k
ThecatsatonitThecatsatonit1.502.73-0.752.43-0.210.091.681.650.060.24-0.06-1.26-0.33-0.15-2.16-0.66-1.20-0.63-0.540.180.872.07-2.761.560.66

不同场景的遮罩策略

场景遮罩类型说明
Encoder Self-Attention无遮罩或 padding mask双向注意力,可以看到整个序列
Decoder Self-Attention因果遮罩只能看到当前和之前的 token
Cross-Attentionpadding maskDecoder 查询 Encoder 输出,无因果约束

Softmax 的数值稳定性

Softmax 数值稳定性朴素实现exp(1000)exp(999)exp(998)→ Inf / NaN减去最大值1.00exp(0)0.37exp(-1)0.14exp(-2)→ 正常计算vs稳定版公式:softmax(x_i) = exp(x_i - max(x)) / Σ exp(x_j - max(x))数学等价,但保证 exp 输入 ≤ 0,不会溢出

溢出问题

朴素的 softmax 实现:

softmax(xi)=exijexj\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}

xix_i 很大时(例如 xi=1000x_i = 1000),e1000e^{1000} 会超出浮点数表示范围,导致数值溢出(得到 InfNaN)。

标准技巧:减去最大值

softmax(xi)=eximax(x)jexjmax(x)\text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}

数学上完全等价(分子分母同乘 emax(x)e^{-\max(x)}),但保证指数的输入 0\leq 0,从而 eximax(x)1e^{x_i - \max(x)} \leq 1,不会溢出。

证明等价性:

eximjexjm=exiemjexjem=exijexj\frac{e^{x_i - m}}{\sum_j e^{x_j - m}} = \frac{e^{x_i} \cdot e^{-m}}{\sum_j e^{x_j} \cdot e^{-m}} = \frac{e^{x_i}}{\sum_j e^{x_j}}

其中 m=max(x)m = \max(x)

实际代码中的实现

所有主流深度学习框架(PyTorch、JAX、TensorFlow)的 softmax 实现都内置了这个技巧。在 Flash Attention 等优化实现中,如何在分块计算中维持数值稳定性是一个更复杂的问题,我们将在后续文章中讨论。

总结

Scaled Dot-Product Attention 的计算可以分解为五个清晰的步骤:

步骤操作输出形状作用
1QKTQK^T(S,S)(S, S)计算所有 token 对之间的相似度
2÷dk\div \sqrt{d_k}(S,S)(S, S)防止点积过大导致梯度消失
3+Mask+ \text{Mask}(S,S)(S, S)遮蔽不应被关注的位置
4Softmax(S,S)(S, S)归一化为概率分布
5×V\times V(S,dv)(S, d_v)按注意力权重聚合 Value

核心直觉:Attention 本质上是一种”软寻址”机制 — 每个 token 根据自己的 Query 和所有 Key 的匹配程度,从所有 Value 中提取信息。缩放保证训练稳定,遮罩保证因果性。

下一篇文章将介绍 Multi-Head Attention — 如何将多个注意力头并行运算并组合,进一步提升模型的表达能力。