在上一篇文章中,我们了解了 Q、K、V 三个矩阵是如何通过线性投影得到的。本文将深入拆解 Attention 的计算过程 — 从 Q、K、V 出发,一步步推导出最终输出。
Attention 机制的本质是一种可微的软检索 :用 Query 去匹配所有 Key,根据匹配程度对 Value 做加权平均。每个 token 的输出不再是固定的,而是根据上下文动态聚合的。
完整公式:Scaled Dot-Product Attention
Transformer 采用的 Attention 形式称为 Scaled Dot-Product Attention :
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V Attention ( Q , K , V ) = softmax ( d k Q K T ) V
其中:
Q ∈ R S × d k Q \in \mathbb{R}^{S \times d_k} Q ∈ R S × d k — Query 矩阵
K ∈ R S × d k K \in \mathbb{R}^{S \times d_k} K ∈ R S × d k — Key 矩阵
V ∈ R S × d v V \in \mathbb{R}^{S \times d_v} V ∈ R S × d v — Value 矩阵(通常 d v = d k d_v = d_k d v = d k )
d k d_k d k — 每个注意力头的维度
d k \sqrt{d_k} d k — 缩放因子,防止点积过大
这个公式看似简洁,但包含了五个关键步骤。下面逐一拆解。
简化模式 (单头) 完整模式 (多头+batch)
Input X (S, H) Q = X·Wq (S, H) K = X·Wk (S, H) V = X·Wv (S, H) reshape (S, d_k) Q·Kᵀ (S, S) ÷ √d_k (S, S) + mask (S, S) softmax (S, S) × V (S, d_v) concat (S, H) × Wo (S, H)
Input X
(S, H)
输入序列的隐藏表示
简化模式省略了 batch (B) 和 多头 (h) 维度
Attention 分步计算 Scaled Dot-Product Attention 四步分解 1 Q·K^T 原始分数 → [S, S] 2 ÷√d_k 缩放 → [S, S] 3 +Mask 因果遮罩 → [S, S] 4 softmax×V 加权输出 → [S, d_v] 每步的矩阵形状:Q·K^T → [S,S] → 缩放 → 遮罩 → softmax → ×V → [S,d_v] 高注意力 遮蔽区域
分步拆解:每一步的数学意义
第一步:Q K T QK^T Q K T — 计算原始注意力分数
Scores = Q K T ∈ R S × S \text{Scores} = QK^T \in \mathbb{R}^{S \times S} Scores = Q K T ∈ R S × S
这是一个矩阵乘法:Q Q Q 的形状为 ( S , d k ) (S, d_k) ( S , d k ) ,K T K^T K T 的形状为 ( d k , S ) (d_k, S) ( d k , S ) ,结果为 ( S , S ) (S, S) ( S , S ) 。
直觉: 结果矩阵的第 i i i 行第 j j j 列是 Query 向量 q i q_i q i 与 Key 向量 k j k_j k j 的点积:
Scores i j = q i ⋅ k j = ∑ l = 1 d k q i l ⋅ k j l \text{Scores}_{ij} = q_i \cdot k_j = \sum_{l=1}^{d_k} q_{il} \cdot k_{jl} Scores ij = q i ⋅ k j = l = 1 ∑ d k q i l ⋅ k j l
点积衡量两个向量的”相似度”:值越大,表示 token i i i 对 token j j j 的关注度越高。
第二步:除以 d k \sqrt{d_k} d k — 缩放
Scaled = Q K T d k \text{Scaled} = \frac{QK^T}{\sqrt{d_k}} Scaled = d k Q K T
为什么需要缩放?这不是一个随意的设计,而是基于严格的统计分析。详见后文”Scaling 的必要性”一节。
第三步:Mask — 遮罩(可选)
Masked i j = { Scaled i j if j ≤ i − ∞ if j > i \text{Masked}_{ij} = \begin{cases} \text{Scaled}_{ij} & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases} Masked ij = { Scaled ij − ∞ if j ≤ i if j > i
在 Decoder 的自注意力中,token i i i 不能看到位置 i i i 之后的 token(因为那些 token 在自回归生成时还不存在)。通过将上三角设为 − ∞ -\infty − ∞ ,softmax 后对应权重变为 0。详见后文”Causal Mask”一节。
第四步:Softmax — 行归一化
Weights i = softmax ( Masked i ) = e Masked i j ∑ j e Masked i j \text{Weights}_i = \text{softmax}(\text{Masked}_i) = \frac{e^{\text{Masked}_{ij}}}{\sum_{j} e^{\text{Masked}_{ij}}} Weights i = softmax ( Masked i ) = ∑ j e Masked ij e Masked ij
对分数矩阵的每一行 独立做 softmax,将原始分数转化为概率分布(非负且和为 1)。
第五步:乘以 V V V — 加权求和
Output = Weights ⋅ V ∈ R S × d v \text{Output} = \text{Weights} \cdot V \in \mathbb{R}^{S \times d_v} Output = Weights ⋅ V ∈ R S × d v
权重矩阵 ( S , S ) (S, S) ( S , S ) 乘以 Value 矩阵 ( S , d v ) (S, d_v) ( S , d v ) ,得到最终输出 ( S , d v ) (S, d_v) ( S , d v ) 。每个 token 的输出是所有 Value 向量的加权平均:
Output i = ∑ j = 1 S Weights i j ⋅ v j \text{Output}_i = \sum_{j=1}^{S} \text{Weights}_{ij} \cdot v_j Output i = j = 1 ∑ S Weights ij ⋅ v j
交互动画:Attention 计算全过程
下面用一个小例子(S = 4 S=4 S = 4 , d k = 3 d_k=3 d k = 3 )演示上述五个步骤的完整计算过程。点击”下一步”逐步查看。
1 2 3 4 5 6 Q 和 K 矩阵
从上一步的线性投影中,我们已经得到了 Q 和 K 矩阵,形状都是 (S=4, d_k=3)。接下来要计算它们之间的注意力分数。
Scaling 的必要性:为什么除以 d k \sqrt{d_k} d k
这是面试和学习中最常被问到的问题之一。原论文(Vaswani et al., 2017)给出了明确的解释:
统计分析
假设 q q q 和 k k k 的每个分量都是独立的随机变量,均值为 0,方差为 1。那么它们的点积:
q ⋅ k = ∑ l = 1 d k q l ⋅ k l q \cdot k = \sum_{l=1}^{d_k} q_l \cdot k_l q ⋅ k = l = 1 ∑ d k q l ⋅ k l
的统计性质为:
E [ q ⋅ k ] = 0 , Var ( q ⋅ k ) = d k \mathbb{E}[q \cdot k] = 0, \quad \text{Var}(q \cdot k) = d_k E [ q ⋅ k ] = 0 , Var ( q ⋅ k ) = d k
方差推导:每个 q l ⋅ k l q_l \cdot k_l q l ⋅ k l 的方差为 Var ( q l ) ⋅ Var ( k l ) = 1 \text{Var}(q_l) \cdot \text{Var}(k_l) = 1 Var ( q l ) ⋅ Var ( k l ) = 1 (因为均值为 0 的随机变量之积的方差等于各自方差之积),d k d_k d k 个独立项求和后方差为 d k d_k d k 。
问题
当 d k d_k d k 较大时(例如 GPT-3 使用 d k = 128 d_k = 128 d k = 128 ),点积的量级约为 128 ≈ 11.3 \sqrt{128} \approx 11.3 128 ≈ 11.3 。这意味着 softmax 的输入值会非常大,导致:
Softmax 输出接近 one-hot :softmax ( [ 10 , 1 , 1 ] ) ≈ [ 0.9999 , 0.0001 , 0.0001 ] \text{softmax}([10, 1, 1]) \approx [0.9999, 0.0001, 0.0001] softmax ([ 10 , 1 , 1 ]) ≈ [ 0.9999 , 0.0001 , 0.0001 ]
梯度几乎消失 :在 softmax 的饱和区,梯度趋近于 0,模型无法有效学习
解决方案
除以 d k \sqrt{d_k} d k 后,点积的方差恢复为 1:
Var ( q ⋅ k d k ) = d k d k = 1 \text{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1 Var ( d k q ⋅ k ) = d k d k = 1
这样 softmax 的输入保持在合理范围内,梯度流畅,训练稳定。
Head 维度 dk : 64
未缩放: QKT
k1 k2 k3 k4 k5 k6 k7 k8 方差: 5.14 · Softmax 熵: 2.056 bits
→ Softmax 输出
k1 k2 k3 k4 k5 k6 k7 k8 缩放后: QKT / √dk
k1 k2 k3 k4 k5 k6 k7 k8 方差: 0.08 · Softmax 熵: 2.947 bits
→ Softmax 输出
k1 k2 k3 k4 k5 k6 k7 k8 观察: d越大 → 未缩放分数的方差越大 → Softmax 输出越接近 one-hot(熵趋近 0)。除以 √d后方差恢复到 ~1,Softmax 输出保持均匀分布(熵接近 3.0 bits)。
原论文原话:“We suspect that for large values of d k d_k d k , the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients.”
Causal Mask:Decoder 的因果遮罩
为什么需要遮罩
在自回归(autoregressive)语言模型中,生成 token i i i 时只能看到 token 1 , 2 , … , i 1, 2, \ldots, i 1 , 2 , … , i ,不能看到 i + 1 , i + 2 , … i+1, i+2, \ldots i + 1 , i + 2 , … (因为它们还没被生成)。
训练时为了并行化,我们会一次性输入整个序列,但需要通过遮罩模拟”看不到未来”的效果。
遮罩矩阵
M i j = { 0 if j ≤ i − ∞ if j > i M_{ij} = \begin{cases} 0 & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases} M ij = { 0 − ∞ if j ≤ i if j > i
将遮罩加到缩放后的分数上:Masked = Scaled + M \text{Masked} = \text{Scaled} + M Masked = Scaled + M
因为 e − ∞ = 0 e^{-\infty} = 0 e − ∞ = 0 ,所以 softmax 后被遮罩位置的权重为 0。
遮罩的形状
对于序列长度 S = 4 S = 4 S = 4 :
M = ( 0 − ∞ − ∞ − ∞ 0 0 − ∞ − ∞ 0 0 0 − ∞ 0 0 0 0 ) M = \begin{pmatrix} 0 & -\infty & -\infty & -\infty \\ 0 & 0 & -\infty & -\infty \\ 0 & 0 & 0 & -\infty \\ 0 & 0 & 0 & 0 \end{pmatrix} M = 0 0 0 0 − ∞ 0 0 0 − ∞ − ∞ 0 0 − ∞ − ∞ − ∞ 0
这是一个下三角矩阵。第 i i i 行只保留前 i i i 个位置的分数。
1 2 3 原始 QKᵀ 分数矩阵
原始点积分数 — 每个格子表示 token i 对 token j 的原始相关性。
Scores = QKᵀ/√d_k
The cat sat on it The cat sat on it 1.50 2.73 -0.75 2.43 -0.21 0.09 1.68 1.65 0.06 0.24 -0.06 -1.26 -0.33 -0.15 -2.16 -0.66 -1.20 -0.63 -0.54 0.18 0.87 2.07 -2.76 1.56 0.66
不同场景的遮罩策略
场景 遮罩类型 说明 Encoder Self-Attention 无遮罩或 padding mask 双向注意力,可以看到整个序列 Decoder Self-Attention 因果遮罩 只能看到当前和之前的 token Cross-Attention padding mask Decoder 查询 Encoder 输出,无因果约束
Softmax 的数值稳定性
Softmax 数值稳定性 朴素实现 ∞ exp(1000) ∞ exp(999) ∞ exp(998) → Inf / NaN 减去最大值 1.00 exp(0) 0.37 exp(-1) 0.14 exp(-2) → 正常计算 vs 稳定版公式:softmax(x_i) = exp(x_i - max(x)) / Σ exp(x_j - max(x)) 数学等价,但保证 exp 输入 ≤ 0,不会溢出
溢出问题
朴素的 softmax 实现:
softmax ( x i ) = e x i ∑ j e x j \text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}} softmax ( x i ) = ∑ j e x j e x i
当 x i x_i x i 很大时(例如 x i = 1000 x_i = 1000 x i = 1000 ),e 1000 e^{1000} e 1000 会超出浮点数表示范围,导致数值溢出(得到 Inf 或 NaN)。
标准技巧:减去最大值
softmax ( x i ) = e x i − max ( x ) ∑ j e x j − max ( x ) \text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}} softmax ( x i ) = ∑ j e x j − m a x ( x ) e x i − m a x ( x )
数学上完全等价(分子分母同乘 e − max ( x ) e^{-\max(x)} e − m a x ( x ) ),但保证指数的输入 ≤ 0 \leq 0 ≤ 0 ,从而 e x i − max ( x ) ≤ 1 e^{x_i - \max(x)} \leq 1 e x i − m a x ( x ) ≤ 1 ,不会溢出。
证明等价性:
e x i − m ∑ j e x j − m = e x i ⋅ e − m ∑ j e x j ⋅ e − m = e x i ∑ j e x j \frac{e^{x_i - m}}{\sum_j e^{x_j - m}} = \frac{e^{x_i} \cdot e^{-m}}{\sum_j e^{x_j} \cdot e^{-m}} = \frac{e^{x_i}}{\sum_j e^{x_j}} ∑ j e x j − m e x i − m = ∑ j e x j ⋅ e − m e x i ⋅ e − m = ∑ j e x j e x i
其中 m = max ( x ) m = \max(x) m = max ( x ) 。
实际代码中的实现
所有主流深度学习框架(PyTorch、JAX、TensorFlow)的 softmax 实现都内置了这个技巧。在 Flash Attention 等优化实现中,如何在分块计算中维持数值稳定性是一个更复杂的问题,我们将在后续文章中讨论。
总结
Scaled Dot-Product Attention 的计算可以分解为五个清晰的步骤:
步骤 操作 输出形状 作用 1 Q K T QK^T Q K T ( S , S ) (S, S) ( S , S ) 计算所有 token 对之间的相似度 2 ÷ d k \div \sqrt{d_k} ÷ d k ( S , S ) (S, S) ( S , S ) 防止点积过大导致梯度消失 3 + Mask + \text{Mask} + Mask ( S , S ) (S, S) ( S , S ) 遮蔽不应被关注的位置 4 Softmax ( S , S ) (S, S) ( S , S ) 归一化为概率分布 5 × V \times V × V ( S , d v ) (S, d_v) ( S , d v ) 按注意力权重聚合 Value
核心直觉 :Attention 本质上是一种”软寻址”机制 — 每个 token 根据自己的 Query 和所有 Key 的匹配程度,从所有 Value 中提取信息。缩放保证训练稳定,遮罩保证因果性。
下一篇文章将介绍 Multi-Head Attention — 如何将多个注意力头并行运算并组合,进一步提升模型的表达能力。