
Multi-Head Attention

Updated 2026-04-06

Introduction: Why a Single Head Is Not Enough

In the previous article, we dissected the computation of Scaled Dot-Product Attention in detail. But in actual Transformers, attention is not computed just once — it is computed many times in parallel. This is Multi-Head Attention.

Why do we need multiple heads? Consider this sentence:

“The animal didn’t cross the street because it was too tired.”

For the token “it,” we need to attend to multiple different types of relationships simultaneously:

  • Coreference: it → animal (semantic coreference)
  • Syntactic relationship: it → was (subject-verb agreement)
  • Causal relationship: it → because (logical connection)

If there is only one attention head, its softmax output is a single probability distribution — it can only produce one pattern of weights. This means the model must “mix” all different types of relationships into the same set of weights, severely limiting expressiveness.

The core idea of Multi-Head Attention: let different attention heads independently compute attention in different subspaces, where each head can focus on different types of patterns, and then concatenate the results.

Mathematical Formulation of Multi-Head

Complete Formula

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$

Where each head is computed as:

$$\text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$$

The shapes of the parameter matrices:

  • $W_i^Q \in \mathbb{R}^{H \times d_k}$ — Query projection for the $i$-th head
  • $W_i^K \in \mathbb{R}^{H \times d_k}$ — Key projection for the $i$-th head
  • $W_i^V \in \mathbb{R}^{H \times d_v}$ — Value projection for the $i$-th head
  • $W^O \in \mathbb{R}^{h d_v \times H}$ — Output projection matrix

Where $H$ is the model's hidden dimension, $h$ is the number of heads, and $d_k = d_v = H / h$.
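As a concrete sanity check of these shapes (a minimal sketch with the original Transformer's sizes; the variable names are ours, not from the article):

```python
import torch

H, h = 512, 8                    # hidden dimension and number of heads (original Transformer)
d_k = d_v = H // h               # 64-dimensional subspace per head
S = 10                           # sequence length, arbitrary for illustration

X = torch.randn(S, H)            # token representations

# Per-head projection matrices, shaped exactly as in the formulas above
W_q_i = torch.randn(H, d_k)      # W_i^Q
W_k_i = torch.randn(H, d_k)      # W_i^K
W_v_i = torch.randn(H, d_v)      # W_i^V
W_o = torch.randn(h * d_v, H)    # W^O

print((X @ W_q_i).shape)         # torch.Size([10, 64]): one head's queries live in a d_k-dim subspace
```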

Relationship to Single-Head Attention

Single-head Attention computes in the full $H$-dimensional space:

$$\text{SingleHead}(Q, K, V) = \text{Attention}(Q, K, V)$$

Multi-Head Attention splits the $H$-dimensional space into $h$ subspaces of $d_k$ dimensions, computes attention independently in each subspace, and then merges:

$$H = h \times d_k$$

Key point: The total parameter count and computation of Multi-Head Attention are nearly identical to those of single-head attention over the full dimension. We have not increased the cost, but rather distributed the computation across multiple parallel subspaces.
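To see why the computation is unchanged, count the multiply-adds in the score matrices (a back-of-the-envelope comparison, ignoring constant factors): splitting into heads shrinks each score matrix's inner dimension but multiplies the number of score matrices by $h$, and the two effects cancel.

$$h \cdot \underbrace{S^2 d_k}_{\text{per-head } QK^T} = S^2 \cdot (h\, d_k) = \underbrace{S^2 H}_{\text{single-head } QK^T}$$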

Intuition Behind Subspace Splitting: Different Heads Attend to Different Patterns

Why is computing attention in subspaces better than in the full space? The intuition is that each head can learn an independent “attention pattern.”

Researchers analyzing trained Transformers have found that different heads indeed learn different functions (see the original paper’s appendix and subsequent studies):

| Head Type | Attention Pattern | Example |
| --- | --- | --- |
| Positional Head | Always attends to adjacent tokens | Head 3 attends to the previous token |
| Syntactic Head | Attends to syntactic dependencies | Head 7 attends to the verb's subject |
| Semantic Head | Attends to semantically related tokens | Head 5 attends to coreference relations |
| Separator Head | Attends to special tokens (e.g., [SEP]) | Head 11 attends to sentence separators |

Each head, within its own low-dimensional subspace, learns different projection patterns through independent $W_i^Q$, $W_i^K$, $W_i^V$, thereby computing attention from different "perspectives."

[Illustration on the sentence "The cat sat on the mat because it was tired": single-head attention mixes all patterns into one weight distribution, while multi-head attention (h=4) separates them: Head 1 a local pattern, Head 2 verb-subject, Head 3 pronoun reference, Head 4 a prepositional phrase.]
Illustrative diagram, not real model weights — shows how multi-head allows different heads to focus on different relational patterns

Dimension Analysis: Detailed Tracking of Reshape and Transpose

In practice, we do not perform separate matrix multiplications for each head — that would be too slow. Instead, we use a single large projection + reshape + transpose to efficiently compute all heads in parallel.

Detailed Tensor Shape Transformations

Using Query as an example (Key and Value follow the same pattern), with batch size $B$, sequence length $S$, hidden dimension $H$, number of heads $h$, and per-head dimension $d_k = H / h$:

Step 1: Linear Projection

$$X \in \mathbb{R}^{B \times S \times H} \xrightarrow{\;W^Q\;} Q \in \mathbb{R}^{B \times S \times H}$$

Here $W^Q \in \mathbb{R}^{H \times H}$ is a unified large projection matrix, equivalent to the concatenation of $h$ small projection matrices $W_i^Q$.
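To make this equivalence concrete, here is a minimal check (hypothetical sizes, written for illustration): slicing the columns of the big $W^Q$ recovers the per-head projection $W_i^Q$, and the corresponding slice of the big projection's output is exactly that head's queries.

```python
import torch

B, S, H, h = 4, 10, 512, 8
d_k = H // h
X = torch.randn(B, S, H)

W_q = torch.randn(H, H)                          # unified large projection
Q_big = X @ W_q                                  # (B, S, H): all heads at once

i = 3                                            # pick one head
W_q_i = W_q[:, i * d_k:(i + 1) * d_k]            # column block i is W_i^Q, shape (H, d_k)
Q_head_i = X @ W_q_i                             # (B, S, d_k): head i's queries

# The slice of the big output equals the small projection's output
print(torch.allclose(Q_big[..., i * d_k:(i + 1) * d_k], Q_head_i, atol=1e-4))  # True
```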

Step 2: Reshape — Split into Heads

$$Q \in \mathbb{R}^{B \times S \times H} \xrightarrow{\text{reshape}} Q \in \mathbb{R}^{B \times S \times h \times d_k}$$

The last dimension is split from $H$ into $h \times d_k$.

Step 3: Transpose — Move the Head Dimension Forward

$$Q \in \mathbb{R}^{B \times S \times h \times d_k} \xrightarrow{\text{transpose}} Q \in \mathbb{R}^{B \times h \times S \times d_k}$$

The $S$ and $h$ dimensions are swapped. This way, each head corresponds to an $(S, d_k)$ matrix, and attention can be computed in parallel across $B \times h$ "batches."
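This is where the parallelism comes from in code: torch.matmul (the @ operator) treats all leading dimensions as batch dimensions, so a single call computes an $S \times S$ score matrix for every head of every sequence. A tiny demonstration with made-up sizes:

```python
import torch

B, h, S, d_k = 2, 8, 10, 64
Q = torch.randn(B, h, S, d_k)
K = torch.randn(B, h, S, d_k)

# (B, h) are broadcast batch dims; the matmul acts on the trailing (S, d_k) x (d_k, S)
scores = Q @ K.transpose(-2, -1)
print(scores.shape)   # torch.Size([2, 8, 10, 10]): one S x S score matrix per head per sequence
```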

Step 4: Scaled Dot-Product Attention

$$Q \in \mathbb{R}^{B \times h \times S \times d_k}, \quad K \in \mathbb{R}^{B \times h \times S \times d_k}, \quad V \in \mathbb{R}^{B \times h \times S \times d_k}$$

$$\text{Scores} = \frac{QK^T}{\sqrt{d_k}} \in \mathbb{R}^{B \times h \times S \times S}$$

$$\text{Output} = \text{softmax}(\text{Scores}) \cdot V \in \mathbb{R}^{B \times h \times S \times d_k}$$

Step 5: Transpose Back + Reshape

$$\mathbb{R}^{B \times h \times S \times d_k} \xrightarrow{\text{transpose}} \mathbb{R}^{B \times S \times h \times d_k} \xrightarrow{\text{reshape}} \mathbb{R}^{B \times S \times H}$$

All heads' outputs are concatenated back into an $H$-dimensional vector.

Step 6: Output Projection

$$\mathbb{R}^{B \times S \times H} \xrightarrow{\;W^O\;} \mathbb{R}^{B \times S \times H}$$

PyTorch Pseudocode

import math
import torch.nn.functional as F

# Linear projection (compute all heads at once)
Q = self.W_q(X)                      # (B, S, H)
K = self.W_k(X)                      # (B, S, H)
V = self.W_v(X)                      # (B, S, H)

# reshape + transpose
Q = Q.view(B, S, h, d_k).transpose(1, 2)  # (B, h, S, d_k)
K = K.view(B, S, h, d_k).transpose(1, 2)  # (B, h, S, d_k)
V = V.view(B, S, h, d_k).transpose(1, 2)  # (B, h, S, d_k)

# Scaled Dot-Product Attention (parallel across all heads)
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (B, h, S, S)
weights = F.softmax(scores, dim=-1)                  # (B, h, S, S)
output = weights @ V                                  # (B, h, S, d_k)

# transpose + reshape + output projection
output = output.transpose(1, 2).contiguous().view(B, S, H)  # (B, S, H)
output = self.W_o(output)                                     # (B, S, H)
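Putting the pieces together, here is a self-contained module version of the same computation (a minimal sketch for illustration: no dropout, bias, or cross-attention, and the class name MultiHeadAttention is ours):

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """Minimal multi-head self-attention: project, split into heads, attend, merge, project."""

    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        assert hidden_dim % num_heads == 0, "H must be divisible by h"
        self.h = num_heads
        self.d_k = hidden_dim // num_heads
        self.W_q = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_k = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_v = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_o = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x, mask=None):
        B, S, H = x.shape

        # Project and split into heads: (B, S, H) -> (B, h, S, d_k)
        q = self.W_q(x).view(B, S, self.h, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, S, self.h, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, S, self.h, self.d_k).transpose(1, 2)

        # Scaled dot-product attention, parallel over all B * h heads
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)    # (B, h, S, S)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        out = weights @ v                                          # (B, h, S, d_k)

        # Concatenate heads and apply the output projection
        out = out.transpose(1, 2).contiguous().view(B, S, H)       # (B, S, H)
        return self.W_o(out)


# Quick shape check
mha = MultiHeadAttention(hidden_dim=512, num_heads=8)
x = torch.randn(2, 10, 512)
print(mha(x).shape)   # torch.Size([2, 10, 512])
```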

Multi-Head Parallel Computation Diagram

[Diagram: Input X (B, S, H) → Linear W_Q / W_K / W_V (each H × H) → reshape + transpose to (B, h, S, d_k) → h heads compute Attention(Qᵢ, Kᵢ, Vᵢ) in parallel (each head: QKᵀ/√d_k, + mask, softmax, × V) → Concat and reshape to (B, S, H) → Linear W_O (H × H) → Output (B, S, H)]

Output Projection: The Role of $W^O$

After multi-head concatenation, the resulting vector is already back to $H$ dimensions. So why do we still need an output projection $W^O$?

Role 1: Fusing Multi-Head Information

Each head computes independently in its own subspace with no cross-head interaction. $W^O$ provides a cross-head linear transformation, allowing information captured by different heads to mix and interact.

Think of it this way:

  • Each head is an “expert” capturing a specific attention pattern
  • $W^O$ is a "fusion layer" that synthesizes all experts' opinions into a final decision, as the sketch below illustrates
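To make the fusion concrete, here is a small numerical check (shapes chosen for illustration): concatenating the heads and applying $W^O$ is exactly the same as letting each head write into the output through its own row block of $W^O$ and summing the contributions, which is how cross-head mixing happens.

```python
import torch

B, S, h, d_v, H = 2, 5, 4, 16, 64
heads = [torch.randn(B, S, d_v) for _ in range(h)]     # outputs of h independent heads
W_o = torch.randn(h * d_v, H)

# Standard form: concatenate all heads, then one output projection
out_concat = torch.cat(heads, dim=-1) @ W_o            # (B, S, H)

# Equivalent form: each head uses its own row block of W_o, contributions are summed
out_sum = sum(heads[i] @ W_o[i * d_v:(i + 1) * d_v] for i in range(h))

print(torch.allclose(out_concat, out_sum, atol=1e-4))  # True
```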

Role 2: Maintaining Residual Connection Compatibility

In Transformers, each sublayer's output needs to be added to the input via a residual connection: $\text{Output} = \text{SubLayer}(X) + X$. $W^O$ ensures that the multi-head attention output has the same dimension as the input, allowing the residual connection to function properly.

Independent Head Outputs

[Diagram: 4 heads each compute attention and produce a $d_k = 3$ dimensional output vector (entries h0d0 through h3d2); the per-head vectors are then concatenated before the output projection.]

Parameter Count Analysis

For a multi-head attention layer, the parameter count of the projection matrices is:

$$\underbrace{3 H^2}_{W^Q,\, W^K,\, W^V} + \underbrace{H^2}_{W^O} = 4H^2$$

Note that the $h$ matrices $W_i^Q$ (each $H \times d_k$) concatenated together are equivalent to a single $H \times H$ matrix — so regardless of how the number of heads changes, the total parameter count remains the same.
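This can be checked against PyTorch's built-in nn.MultiheadAttention (biases disabled so only the four projection matrices are counted); the parameter count stays at $4H^2$ however many heads are used:

```python
import torch.nn as nn

H = 768
for num_heads in (8, 12, 16):
    mha = nn.MultiheadAttention(embed_dim=H, num_heads=num_heads, bias=False)
    n_params = sum(p.numel() for p in mha.parameters())
    print(num_heads, n_params, 4 * H * H)   # 2359296 parameters in every case
```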

Typical Configurations: Head Designs Across Different Models

| Model | Hidden Dim $H$ | Heads $h$ | Per-Head Dim $d_k$ | Layers |
| --- | --- | --- | --- | --- |
| Transformer (original paper) | 512 | 8 | 64 | 6 |
| GPT-2 (Small) | 768 | 12 | 64 | 12 |
| GPT-2 (Medium) | 1024 | 16 | 64 | 24 |
| GPT-3 (175B) | 12288 | 96 | 128 | 96 |
| LLaMA-7B | 4096 | 32 | 128 | 32 |
| LLaMA-65B | 8192 | 64 | 128 | 80 |

Interesting observations:

  • Smaller models typically use $d_k = 64$ (the original paper's choice)
  • Larger models tend to use $d_k = 128$ (greater subspace capacity)
  • The number of heads grows with model size, while $d_k$ stays within a narrow range (64 or 128)
  • This suggests there is a reasonable range for "subspace granularity"; more heads simply means more parallel attention patterns

Variants: Multi-Query Attention and Grouped-Query Attention

In recent years, several efficiency-optimized variants have emerged:

  • Multi-Query Attention (MQA): All heads share the same set of K and V; only Q differs. This drastically reduces the memory footprint of the KV cache.
  • Grouped-Query Attention (GQA): A compromise approach that divides the $h$ heads into $g$ groups, with each group sharing K and V. LLaMA-2 70B uses 8-group GQA.

The core trade-off of these variants is: accepting a small accuracy loss in exchange for significant inference efficiency gains.
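As a rough illustration of the efficiency side of that trade-off (a back-of-the-envelope sketch with assumed LLaMA-7B-like sizes, not measured numbers), the KV cache cost scales with the number of key/value heads:

```python
# Per token and per layer, the KV cache stores K and V: 2 * kv_heads * d_k values
H, h, d_k = 4096, 32, 128                     # assumed LLaMA-7B-like configuration

for name, kv_heads in [("MHA", h), ("GQA (8 groups)", 8), ("MQA", 1)]:
    floats_per_token = 2 * kv_heads * d_k
    print(f"{name:15s} kv_heads={kv_heads:2d}  cache floats per token per layer = {floats_per_token}")
```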

Summary

The design philosophy of Multi-Head Attention is divide and conquer:

| Concept | Description |
| --- | --- |
| Why multiple heads | A single head can learn only one attention pattern; multiple heads can attend to different relationships in parallel |
| Subspace splitting | $H = h \times d_k$; each head computes independently in a $d_k$-dimensional subspace |
| Implementation trick | reshape + transpose enables parallel computation across all heads with no extra overhead |
| Dimension flow | $(B,S,H) \to (B,h,S,d_k) \to \text{Attention} \to (B,S,H)$ |
| Output projection | $W^O$ fuses information from all heads and maintains dimensional compatibility |
| Parameter count | $4H^2$, independent of the number of heads |

Core intuition: Multi-Head Attention is not simply “doing Attention multiple times.” Its elegance lies in splitting a high-dimensional space into multiple low-dimensional subspaces, enabling the model to understand token relationships from multiple perspectives simultaneously without increasing computation. It is like filming the same scene with multiple cameras from different angles — each angle captures unique information, and only together do they form the complete picture.