Multi-Head Attention
Updated 2026-04-06
Introduction: Why a Single Head Is Not Enough
In the previous article, we dissected the computation of Scaled Dot-Product Attention in detail. But in an actual Transformer, attention is not computed just once: it is computed multiple times in parallel. This is Multi-Head Attention.
Why do we need multiple heads? Consider this sentence:
“The animal didn’t cross the street because it was too tired.”
For the token “it,” we need to attend to multiple different types of relationships simultaneously:
- Coreference: it → animal (semantic coreference)
- Syntactic relationship: it → was (subject-verb agreement)
- Causal relationship: it → because (logical connection)
If there is only one attention head, its softmax output is a single probability distribution — it can only produce one pattern of weights. This means the model must “mix” all different types of relationships into the same set of weights, severely limiting expressiveness.
The core idea of Multi-Head Attention: let different attention heads independently compute attention in different subspaces, where each head can focus on different types of patterns, and then concatenate the results.
Mathematical Formulation of Multi-Head
Complete Formula
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$

Where each head is computed as:

$$\text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V) = \text{softmax}\!\left(\frac{Q W_i^Q\,(K W_i^K)^\top}{\sqrt{d_k}}\right) V W_i^V$$

The shapes of the parameter matrices:
- $W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$: Query projection for the $i$-th head
- $W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$: Key projection for the $i$-th head
- $W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$: Value projection for the $i$-th head
- $W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$: Output projection matrix
Where $d_{\text{model}}$ is the model’s hidden dimension, $h$ is the number of heads, and $d_k = d_v = d_{\text{model}} / h$.
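To make the formula concrete, here is a deliberately naive sketch: random tensors stand in for learned weights, and a loop over heads replaces the batched implementation shown later in this article. Each head is computed with its own small projection matrices and the results are concatenated, exactly as written above.

```python
import math
import torch
import torch.nn.functional as F

S, d_model, h = 10, 512, 8           # sequence length, hidden dim, number of heads
d_k = d_v = d_model // h             # per-head dimension (64)

X = torch.randn(S, d_model)                              # token representations (self-attention: Q = K = V = X)
W_Q = [torch.randn(d_model, d_k) for _ in range(h)]      # W_i^Q
W_K = [torch.randn(d_model, d_k) for _ in range(h)]      # W_i^K
W_V = [torch.randn(d_model, d_v) for _ in range(h)]      # W_i^V
W_O = torch.randn(h * d_v, d_model)                      # W^O

heads = []
for i in range(h):
    Q, K, V = X @ W_Q[i], X @ W_K[i], X @ W_V[i]         # (S, d_k) each
    scores = Q @ K.T / math.sqrt(d_k)                    # (S, S)
    heads.append(F.softmax(scores, dim=-1) @ V)          # head_i: (S, d_v)

out = torch.cat(heads, dim=-1) @ W_O                     # Concat(head_1..head_h) W^O
print(out.shape)                                         # torch.Size([10, 512])
```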
Relationship to Single-Head Attention
Single-head Attention computes in the full $d_{\text{model}}$-dimensional space:

$$\text{Attention}(Q W^Q,\; K W^K,\; V W^V), \qquad W^Q, W^K, W^V \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$$

Multi-Head Attention splits the $d_{\text{model}}$-dimensional space into $h$ subspaces of $d_k = d_{\text{model}} / h$ dimensions, computes attention independently in each subspace, and then merges the results.
Key point: The total parameter count and computation of Multi-Head Attention are nearly identical to those of single-head attention using the full $d_{\text{model}}$ dimension. We have not increased cost, but rather distributed the computation across multiple parallel subspaces.
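A quick worked check with the original Transformer’s sizes ($d_{\text{model}} = 512$, $h = 8$, $d_k = 64$): for each of the $Q$, $K$, $V$ projections, the $h$ per-head matrices together contain exactly as many parameters as one full-size single-head projection.

$$h \cdot (d_{\text{model}} \cdot d_k) = 8 \cdot (512 \cdot 64) = 262{,}144 = 512 \cdot 512 = d_{\text{model}} \cdot d_{\text{model}}$$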
Intuition Behind Subspace Splitting: Different Heads Attend to Different Patterns
Why is computing attention in subspaces better than in the full space? The intuition is that each head can learn an independent “attention pattern.”
Researchers analyzing trained Transformers have found that different heads indeed learn different functions (see the original paper’s appendix and subsequent studies):
| Head Type | Attention Pattern | Example |
|---|---|---|
| Positional Head | Always attends to adjacent tokens | Head 3 attends to the previous token |
| Syntactic Head | Attends to syntactic dependencies | Head 7 attends to the verb’s subject |
| Semantic Head | Attends to semantically related tokens | Head 5 attends to coreference relations |
| Separator Head | Attends to special tokens (e.g., [SEP]) | Head 11 attends to sentence separators |
Each head, within its own low-dimensional subspace, learns different projection patterns through its independent $W_i^Q$, $W_i^K$, $W_i^V$, thereby computing attention from different “perspectives.”
Illustrative diagram, not real model weights — shows how multi-head allows different heads to focus on different relational patterns
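As a hint at how such analyses are carried out in practice, per-head attention maps can be pulled out of any implementation that keeps the softmax weights around. The snippet below is a hypothetical example assuming a `weights` tensor of shape `(B, h, S, S)`, like the one produced by the pseudocode later in this article:

```python
# weights: (B, h, S, S), one (S, S) attention map per head
head_id = 3
attn_map = weights[0, head_id]            # (S, S) map of head 3 for the first sequence
top_attended = attn_map.argmax(dim=-1)    # for each query position, the key position it attends to most
```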
Dimension Analysis: Detailed Tracking of Reshape and Transpose
In practice, we do not perform separate matrix multiplications for each head — that would be too slow. Instead, we use a single large projection + reshape + transpose to efficiently compute all heads in parallel.
Detailed Tensor Shape Transformations
Using Query as an example (Key and Value follow the same pattern), with batch size $B$, sequence length $S$, hidden dimension $d_{\text{model}}$, number of heads $h$, and per-head dimension $d_k = d_{\text{model}} / h$:
Step 1: Linear Projection
$$Q = X W^Q: \quad (B,\, S,\, d_{\text{model}}) \rightarrow (B,\, S,\, d_{\text{model}})$$

Here $W^Q \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$ is a unified large projection matrix, equivalent to the concatenation of $h$ small projection matrices $W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$.
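A small numerical check of this equivalence (toy sizes, random weights, purely illustrative): slicing the columns of the unified matrix recovers the per-head projections.

```python
import torch

B, S, d_model, h = 2, 5, 512, 8
d_k = d_model // h

X = torch.randn(B, S, d_model)
W_q = torch.randn(d_model, d_model)                  # unified projection matrix

Q_big = X @ W_q                                      # (B, S, d_model) in one matmul
Q_per_head = [X @ W_q[:, i * d_k:(i + 1) * d_k]      # h slices, each (B, S, d_k)
              for i in range(h)]
assert torch.allclose(Q_big, torch.cat(Q_per_head, dim=-1), atol=1e-4)
```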
Step 2: Reshape — Split into Heads
$$(B,\, S,\, d_{\text{model}}) \rightarrow (B,\, S,\, h,\, d_k)$$

The last dimension is split from $d_{\text{model}}$ into $h \times d_k$.
Step 3: Transpose — Move the Head Dimension Forward
$$(B,\, S,\, h,\, d_k) \rightarrow (B,\, h,\, S,\, d_k)$$

The $S$ and $h$ dimensions are swapped. This way, each head corresponds to an $S \times d_k$ matrix, and attention can be computed in parallel across the $B \times h$ “batches.”
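A toy shape walk-through of the split-and-transpose trick (arbitrary sizes, purely illustrative):

```python
import torch

B, S, d_model, h = 2, 10, 512, 8
d_k = d_model // h

Q = torch.randn(B, S, d_model)
Q = Q.view(B, S, h, d_k)     # (2, 10, 8, 64): last dim split into h heads of size d_k
Q = Q.transpose(1, 2)        # (2, 8, 10, 64): one (S, d_k) matrix per head and batch element
print(Q.shape)               # torch.Size([2, 8, 10, 64])
```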
Step 4: Scaled Dot-Product Attention

$$\text{scores} = \frac{Q K^\top}{\sqrt{d_k}}: \;(B,\, h,\, S,\, S) \qquad \text{output} = \text{softmax}(\text{scores})\, V: \;(B,\, h,\, S,\, d_k)$$

Attention weights and outputs are computed for all $B \times h$ heads at once.
Step 5: Transpose Back + Reshape
$$(B,\, h,\, S,\, d_k) \rightarrow (B,\, S,\, h,\, d_k) \rightarrow (B,\, S,\, d_{\text{model}})$$

All heads’ outputs are concatenated back into a $d_{\text{model}}$-dimensional vector.
Step 6: Output Projection

$$\text{output} = \text{output}\, W^O: \quad (B,\, S,\, d_{\text{model}}) \rightarrow (B,\, S,\, d_{\text{model}})$$
PyTorch Pseudocode
```python
# Inside the forward pass of a multi-head attention module.
# Assumes: import math; import torch.nn.functional as F
#   X: (B, S, H) input, where B = batch size, S = sequence length, H = d_model
#   h = number of heads, d_k = H // h
#   self.W_q / W_k / W_v / W_o are nn.Linear(H, H) layers

# Linear projection (compute all heads at once)
Q = self.W_q(X)  # (B, S, H)
K = self.W_k(X)  # (B, S, H)
V = self.W_v(X)  # (B, S, H)

# reshape + transpose
Q = Q.view(B, S, h, d_k).transpose(1, 2)  # (B, h, S, d_k)
K = K.view(B, S, h, d_k).transpose(1, 2)  # (B, h, S, d_k)
V = V.view(B, S, h, d_k).transpose(1, 2)  # (B, h, S, d_k)

# Scaled Dot-Product Attention (parallel across all heads)
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (B, h, S, S)
weights = F.softmax(scores, dim=-1)                # (B, h, S, S)
output = weights @ V                               # (B, h, S, d_k)

# transpose + reshape + output projection
output = output.transpose(1, 2).contiguous().view(B, S, H)  # (B, S, H)
output = self.W_o(output)  # (B, S, H)
```
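For completeness, here is a minimal self-contained module that the snippet above could live in. This is a sketch under stated assumptions, not a reference implementation: the class and attribute names (`MultiHeadAttention`, `W_q`, `W_k`, `W_v`, `W_o`) are chosen to match the pseudocode, and masking and dropout are omitted.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Minimal multi-head self-attention (no masking, no dropout)."""

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.h = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        B, S, H = X.shape
        # Project, split into heads, and move the head dimension forward
        Q = self.W_q(X).view(B, S, self.h, self.d_k).transpose(1, 2)  # (B, h, S, d_k)
        K = self.W_k(X).view(B, S, self.h, self.d_k).transpose(1, 2)  # (B, h, S, d_k)
        V = self.W_v(X).view(B, S, self.h, self.d_k).transpose(1, 2)  # (B, h, S, d_k)
        # Scaled dot-product attention for all heads in parallel
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)        # (B, h, S, S)
        weights = F.softmax(scores, dim=-1)                           # (B, h, S, S)
        out = weights @ V                                             # (B, h, S, d_k)
        # Merge heads and apply the output projection
        out = out.transpose(1, 2).contiguous().view(B, S, H)          # (B, S, d_model)
        return self.W_o(out)

# Quick shape check
x = torch.randn(2, 10, 512)
print(MultiHeadAttention(d_model=512, num_heads=8)(x).shape)  # torch.Size([2, 10, 512])
```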
Multi-Head Parallel Computation Diagram
Output Projection: The Role of $W^O$
After multi-head concatenation, the resulting vector is already back to $d_{\text{model}}$ dimensions. So why do we still need an output projection $W^O$?
Role 1: Fusing Multi-Head Information
Each head computes independently in its own subspace with no cross-head interaction. $W^O$ provides a cross-head linear transformation, allowing information captured by different heads to mix and interact.
Think of it this way:
- Each head is an “expert” capturing a specific attention pattern
- $W^O$ is a “fusion layer” that synthesizes all experts’ opinions into a final decision
Role 2: Maintaining Residual Connection Compatibility
In Transformers, each sublayer’s output needs to be added to the input via a residual connection: $\text{output} = x + \text{Sublayer}(x)$. $W^O$ ensures that the multi-head attention output has the same dimension and value range as the input, allowing the residual connection to function properly.
4 heads each compute Attention to get $d_k = 3$ dimensional output vectors; each head is shown in a different color.
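A minimal sketch of this sublayer pattern, assuming `mha` and `layer_norm` modules defined elsewhere (the original Transformer applies LayerNorm after the residual add; many newer models apply it before):

```python
# Residual connection around the attention sublayer (post-LN variant)
attn_out = mha(x)               # (B, S, d_model), matching x thanks to W_o
x = layer_norm(x + attn_out)    # the addition only works because the shapes match
```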
Parameter Count Analysis
For a multi-head attention layer, the parameter count of the projection matrices is:

$$\underbrace{3 \times d_{\text{model}} \times d_{\text{model}}}_{W^Q,\; W^K,\; W^V} \;+\; \underbrace{d_{\text{model}} \times d_{\text{model}}}_{W^O} \;=\; 4\, d_{\text{model}}^2$$

Note that the $h$ per-head matrices (each $d_{\text{model}} \times d_k$) concatenated together are equivalent to a single $d_{\text{model}} \times d_{\text{model}}$ matrix, so regardless of how the number of heads changes, the total parameter count remains the same.
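As a sanity check, one can count the weights of the sketch `MultiHeadAttention` module from earlier in this article (biases excluded, since the formula above counts only the projection matrices):

```python
mha = MultiHeadAttention(d_model=512, num_heads=8)
weight_count = sum(p.numel() for name, p in mha.named_parameters() if name.endswith("weight"))
print(weight_count)   # 1048576 = 4 * 512**2, and it stays the same for num_heads = 4, 16, ...
```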
Typical Configurations: Head Designs Across Different Models
| Model | Hidden Dim | Heads | Per-Head Dim | Layers |
|---|---|---|---|---|
| Transformer (original paper) | 512 | 8 | 64 | 6 |
| GPT-2 (Small) | 768 | 12 | 64 | 12 |
| GPT-2 (Medium) | 1024 | 16 | 64 | 24 |
| GPT-3 (175B) | 12288 | 96 | 128 | 96 |
| LLaMA-7B | 4096 | 32 | 128 | 32 |
| LLaMA-65B | 8192 | 64 | 128 | 80 |
Interesting observations:
- Smaller models typically use $d_k = 64$ (the original paper’s choice)
- Larger models tend to use $d_k = 128$ (greater subspace capacity)
- The number of heads grows with model size, but $d_k$ usually stays constant
- This suggests there is a reasonable range for “subspace granularity,” and more heads means more parallel attention patterns
Variants: Multi-Query Attention and Grouped-Query Attention
In recent years, several efficiency-optimized variants have emerged:
- Multi-Query Attention (MQA): All heads share the same set of K and V; only Q differs. This drastically reduces the memory footprint of the KV cache.
- Grouped-Query Attention (GQA): A compromise approach that divides heads into groups, with each group sharing K and V. LLaMA-2 70B uses 8-group GQA.
The core trade-off of these variants is: accepting a small accuracy loss in exchange for significant inference efficiency gains.
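A minimal sketch of how the K/V projections change under GQA, reusing the shape conventions from the earlier pseudocode. This is an assumption-laden illustration, not any particular model's code: `g` is the number of KV groups, and `self.W_k` / `self.W_v` are assumed to be `nn.Linear(d_model, g * d_k)` rather than `nn.Linear(d_model, d_model)` (MQA is the special case `g = 1`):

```python
# Queries keep all h heads; Keys/Values are projected to only g head groups
Q = self.W_q(X).view(B, S, h, d_k).transpose(1, 2)   # (B, h, S, d_k)
K = self.W_k(X).view(B, S, g, d_k).transpose(1, 2)   # (B, g, S, d_k): the KV cache shrinks by h / g
V = self.W_v(X).view(B, S, g, d_k).transpose(1, 2)   # (B, g, S, d_k)

# Each group of h // g query heads shares one K/V head
K = K.repeat_interleave(h // g, dim=1)               # (B, h, S, d_k)
V = V.repeat_interleave(h // g, dim=1)               # (B, h, S, d_k)
# ...the attention computation itself is unchanged from standard multi-head attention
```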
Summary
The design philosophy of Multi-Head Attention is divide and conquer:
| Concept | Description |
|---|---|
| Why multiple heads | A single head can learn only one attention pattern; multiple heads can attend to different relationships in parallel |
| Subspace splitting | $d_k = d_{\text{model}} / h$; each head computes independently in a $d_k$-dimensional subspace |
| Implementation trick | reshape + transpose enables parallel computation across all heads with no extra overhead |
| Dimension flow | $(B, S, d_{\text{model}}) \rightarrow (B, h, S, d_k) \rightarrow (B, S, d_{\text{model}})$ |
| Output projection | $W^O$ fuses information from all heads and maintains dimensional compatibility |
| Parameter count | $4 d_{\text{model}}^2$, independent of the number of heads |
Core intuition: Multi-Head Attention is not simply “doing Attention multiple times.” Its elegance lies in splitting a high-dimensional space into multiple low-dimensional subspaces, enabling the model to understand token relationships from multiple perspectives simultaneously without increasing computation. It is like filming the same scene with multiple cameras from different angles — each angle captures unique information, and only together do they form the complete picture.