
MQA and GQA

Updated 2026-04-06

Introduction: Why MHA’s KV Cache Is the Bottleneck

In the previous article, we learned how Multi-Head Attention (MHA) works: $h$ heads each have their own independent $W^Q$, $W^K$, $W^V$ projection matrices, computing attention in parallel across different subspaces.

MHA performs excellently during training — all tokens can be processed in parallel. But during inference (autoregressive generation), a serious efficiency bottleneck emerges: the KV Cache.

What Is KV Cache

In autoregressive generation, each new token needs to compute attention with all preceding tokens. To avoid redundant computation, we cache the Key and Value vectors of previous tokens — this is the KV Cache.
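
As a rough illustration (a minimal single-head sketch with made-up dimensions, not how production inference engines manage the cache), each decode step computes K/V only for the newest token, appends them to the cache, and attends over everything cached so far:

# Minimal single-head sketch of KV caching during autoregressive decoding.
# d_model and d_k are hypothetical; a real model does this per layer and per head.
import math
import torch
import torch.nn as nn

d_model, d_k = 512, 64
W_q = nn.Linear(d_model, d_k, bias=False)
W_k = nn.Linear(d_model, d_k, bias=False)
W_v = nn.Linear(d_model, d_k, bias=False)

k_cache = torch.empty(0, d_k)  # one row per already-processed token
v_cache = torch.empty(0, d_k)

def decode_step(x_new):                      # x_new: (1, d_model), the newest token
    global k_cache, v_cache
    # K and V are computed once for the new token and appended; old rows are reused
    k_cache = torch.cat([k_cache, W_k(x_new)], dim=0)             # (t, d_k)
    v_cache = torch.cat([v_cache, W_v(x_new)], dim=0)             # (t, d_k)
    q = W_q(x_new)                                                # (1, d_k)
    attn = torch.softmax(q @ k_cache.T / math.sqrt(d_k), dim=-1)  # (1, t)
    return attn @ v_cache                                         # (1, d_k)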

For standard MHA, the KV Cache size per layer is:

$$\text{KV Cache per layer} = 2 \times h \times S \times d_k \times \text{sizeof(dtype)}$$

where $h$ is the number of heads, $S$ is the sequence length, $d_k$ is the per-head dimension, and the factor of 2 accounts for both K and V caches.

Taking LLaMA-2 70B's dimensions as an example ($h = 64$, $d_k = 128$, $L = 80$ layers) and assuming standard MHA, when generating a sequence of length $S = 4096$, the KV Cache per request would be:

$$2 \times 64 \times 4096 \times 128 \times 80 \times 2\ \text{bytes} \approx 10.7\ \text{GB (FP16)}$$

With only a dozen or so concurrent requests, the KV Cache alone rivals the roughly 140 GB occupied by the FP16 model parameters themselves. When serving multiple users simultaneously (batch serving), the KV Cache quickly exhausts GPU memory, becoming the core bottleneck for throughput.

Key observation: The KV Cache size is proportional to the number of heads $h$. If we can reduce the number of KV heads that need to be cached, we can directly shrink the KV Cache.
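
As a sanity check, the formula is easy to compute directly (kv_cache_bytes is just an illustrative helper, not from any library); with 64 KV heads it reproduces the 10.7 GB figure above, and reducing kv_heads shrinks the result proportionally:

def kv_cache_bytes(kv_heads, seq_len, d_k, n_layers, bytes_per_elem=2):
    # factor of 2 covers K and V; bytes_per_elem=2 corresponds to FP16
    return 2 * kv_heads * seq_len * d_k * n_layers * bytes_per_elem

print(kv_cache_bytes(kv_heads=64, seq_len=4096, d_k=128, n_layers=80) / 1e9)  # ~10.7 GB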

MQA: The Most Aggressive KV Reduction

Multi-Query Attention (MQA) was proposed by Noam Shazeer in 2019 (“Fast Transformer Decoding: One Write-Head is All You Need”) and represents the most aggressive KV reduction approach.

Core Idea

The modification in MQA is simple: all query heads share a single set of Key and Value.

  • Each head still has its own independent $W_i^Q$ (Query projections remain distinct)
  • But all heads share one $W^K$ and one $W^V$

Mathematical formulation:

$$\text{head}_i = \text{Attention}(X W_i^Q,\ X W^K,\ X W^V)$$

Note that $W^K$ and $W^V$ no longer carry the subscript $i$; they are shared across all heads.
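
In code, the only change relative to MHA is the output width of the K and V projections (a sketch of just the projection layers with made-up sizes; the full GQA module later in this article covers MQA as the $g = 1$ special case):

import torch.nn as nn

d_model, h, d_k = 4096, 32, 128   # hypothetical sizes

# MHA: every head has its own K/V projection
W_k_mha = nn.Linear(d_model, h * d_k)
W_v_mha = nn.Linear(d_model, h * d_k)

# MQA: a single K/V head shared by all h query heads
W_q     = nn.Linear(d_model, h * d_k)   # queries stay per-head
W_k_mqa = nn.Linear(d_model, d_k)       # one shared K head
W_v_mqa = nn.Linear(d_model, d_k)       # one shared V head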

KV Cache Reduction

Since there is only one set of KV, the cache size shrinks from:

$$2 \times h \times S \times d_k \quad \xrightarrow{\text{MQA}} \quad 2 \times 1 \times S \times d_k$$

A reduction by a factor of $h$! For a model with $h = 64$, this means the KV Cache shrinks to $\frac{1}{64}$ of the original.

The Cost

MQA’s reduction is extreme and comes with notable drawbacks:

  • Quality degradation: All heads are forced to compute attention in the same KV subspace, losing MHA’s ability for “different heads to attend to different patterns”
  • Training instability: Training an MQA model from scratch can be less stable and harder to converge
  • The paper reports “only minor quality degradation,” but in practice, quality loss can be more noticeable on downstream tasks, especially those requiring fine-grained reasoning

GQA: The Grouped-Query Compromise

Grouped-Query Attention (GQA) was proposed by Ainslie et al. in 2023, offering an elegant compromise between MHA and MQA.

Core Idea

Divide the $h$ query heads into $g$ groups, with each group sharing a single pair of K and V heads.

  • When $g = h$ (one head per group) → degenerates to standard MHA
  • When $g = 1$ (all heads in one group) → degenerates to MQA
  • When $1 < g < h$ → GQA, the compromise approach

Mathematical formulation (where the $i$-th query head belongs to group $\lfloor i \cdot g / h \rfloor$):

$$\text{head}_i = \text{Attention}(X W_i^Q,\ X W_{\text{group}(i)}^K,\ X W_{\text{group}(i)}^V)$$
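
The grouping formula can be checked with a few lines (illustrative only, using the $h = 4$ example from the comparison below):

h = 4  # query heads

def group(i, g):
    return (i * g) // h   # the i-th query head reads from KV head floor(i*g/h)

for g in (4, 2, 1):       # MHA, GQA, MQA
    print(g, [group(i, g) for i in range(h)])
# 4 -> [0, 1, 2, 3]   every query head has its own KV head (MHA)
# 2 -> [0, 0, 1, 1]   pairs of query heads share a KV head (GQA)
# 1 -> [0, 0, 0, 0]   all query heads share one KV head (MQA)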

KV Cache Reduction

$$2 \times h \times S \times d_k \quad \xrightarrow{\text{GQA}} \quad 2 \times g \times S \times d_k$$

A reduction by a factor of $h / g$. For example, with $h = 64$ and $g = 8$, the KV Cache shrinks to $\frac{1}{8}$ of the original.

Key Innovation: Uptraining

Another important contribution of the GQA paper is proposing a method to uptrain from an existing MHA checkpoint to GQA:

  1. Initialize the shared KV head by taking the mean of the multiple KV head weights within each group from the original MHA model
  2. Only about 5% of the original pretraining compute is needed to complete the conversion
  3. The converted model’s quality approaches the original MHA, while inference speed approaches MQA

This means there is no need to train from scratch — existing MHA models can be efficiently converted to GQA models.
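
A minimal sketch of the mean-pooling initialization (assuming, purely for illustration, that the MHA per-head K-projection weights are stored as an (h, d_k, d_model) tensor; real checkpoints lay their weights out differently):

import torch

h, g, d_k, d_model = 8, 2, 64, 512      # hypothetical: 8 MHA KV heads -> 2 GQA KV heads

W_k_mha = torch.randn(h, d_k, d_model)  # stand-in for the original per-head K weights

# Mean-pool the h // g heads within each group to initialize the g shared KV heads
W_k_gqa = W_k_mha.view(g, h // g, d_k, d_model).mean(dim=1)   # (g, d_k, d_model)
print(W_k_gqa.shape)                    # torch.Size([2, 64, 512])
# The same pooling is applied to W_v, followed by a short uptraining run to recover quality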

[Figure: uptraining from MHA to GQA, Step 1: the original MHA model's 8 independent KV head weight matrices (KV1 through KV8), which are mean-pooled within each group to initialize the shared GQA KV heads]

Structural Comparison: MHA vs MQA vs GQA

[Figure: head-to-KV mapping for MHA, GQA, and MQA. MHA panel (one-to-one): 4 Q heads (Q₁ to Q₄), each with its own independent KV head (KV₁ to KV₄); KV heads = h = 4]

The diagram above shows the head-to-KV mapping relationships for the three attention mechanisms (using $h = 4$ as an example):

  • MHA: 4 Query heads each map to 1 independent KV head (4 KV heads total)
  • GQA ($g = 2$): 4 Query heads are divided into 2 groups, each sharing 1 KV head (2 KV heads total)
  • MQA: All 4 Query heads share 1 KV head (1 KV head total)

KV Cache Memory Analysis: Concrete Numbers

Let us compute the KV Cache memory footprint using real model parameters. Assume sequence length $S = 4096$ and FP16 (2 bytes per element):

| Model | Layers $L$ | $h$ | $d_k$ | KV heads | KV Cache / request |
| --- | --- | --- | --- | --- | --- |
| Hypothetical MHA-70B | 80 | 64 | 128 | 64 (MHA) | 10.7 GB |
| LLaMA-2 70B (GQA) | 80 | 64 | 128 | 8 | 1.3 GB |
| Hypothetical MQA-70B | 80 | 64 | 128 | 1 (MQA) | 0.17 GB |

Computation formula:

$$\text{KV Cache} = 2 \times \text{KV heads} \times S \times d_k \times L \times 2\ \text{bytes}$$

Taking LLaMA-2 70B’s GQA configuration as an example:

$$2 \times 8 \times 4096 \times 128 \times 80 \times 2\ \text{bytes} \approx 1.34\ \text{GB}$$

Comparison: From MHA's 10.7 GB down to GQA's 1.3 GB, a reduction of approximately 8x ($h / g = 64 / 8 = 8$); with MQA it could be reduced 64x, to just 0.17 GB.
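
The whole table can be reproduced with the same formula (a self-contained snippet; numbers are decimal GB as above):

S, d_k, L = 4096, 128, 80
for name, kv_heads in [("MHA", 64), ("GQA", 8), ("MQA", 1)]:
    gb = 2 * kv_heads * S * d_k * L * 2 / 1e9   # 2 bytes per element (FP16)
    print(f"{name}: {gb:.2f} GB")               # MHA: 10.74, GQA: 1.34, MQA: 0.17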

[Figure: KV Cache size vs. context length for h=64, kv_heads=8, L=80, d_k=128; at a 128K context, MHA ≈ 320 GB, GQA ≈ 40 GB, MQA ≈ 5 GB per request]

Impact on Batch Serving

The impact of KV Cache reduction on batch inference is even more significant. Assuming the GPU has 40 GB of remaining memory available for KV Cache:

| Approach | KV Cache / request | Max concurrent requests |
| --- | --- | --- |
| MHA | 10.7 GB | ~3 |
| GQA (8 groups) | 1.3 GB | ~30 |
| MQA | 0.17 GB | ~235 |

GQA increases concurrency capacity by approximately 10x — this has a decisive impact on cost and latency for LLM serving.
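
The concurrency column is simply the memory budget divided by the per-request cache, rounded down:

budget_gb = 40
for name, per_request_gb in [("MHA", 10.7), ("GQA", 1.3), ("MQA", 0.17)]:
    print(name, int(budget_gb / per_request_gb))   # MHA 3, GQA 30, MQA 235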

[Figure: concurrent requests within a 40 GB KV Cache budget: MHA 4 requests (10.00 GB/req), GQA 32 requests (1.25 GB/req), MQA 256 requests (0.16 GB/req); based on LLaMA-2 70B parameters (L=80, h=64, d_k=128, GQA kv_heads=8), FP16]

Quality vs Performance Trade-off

Reducing the number of KV heads is fundamentally a form of information compression: forcing multiple query heads to find different attention patterns within the same KV subspace.

Why GQA’s Quality Loss Is Small

  1. Redundancy: Research has found that adjacent heads’ KV projections in MHA are often highly similar — many heads learn redundant KV representations
  2. Query diversity preserved: GQA retains the independence of all query heads; only the KV space is shared. Query projections can still learn different attention patterns within the shared KV space
  3. Uptraining effectiveness: Initializing from an MHA checkpoint via mean pooling + a small amount of continued training can efficiently recover quality

The GQA paper reports that with approximately 5% of the original pretraining compute for uptraining, the GQA model performs close to the original MHA model on most benchmarks, while achieving inference speeds close to MQA.

Sources of Speed Improvement

The speed improvements from KV Cache reduction come primarily from two aspects:

  1. Memory bandwidth: Autoregressive decoding is a memory-bandwidth-bound operation. A smaller KV Cache means less data needs to be loaded per generation step, directly improving generation speed (see the rough estimate after this list)
  2. Memory capacity: A smaller KV Cache allows larger batch sizes, improving GPU utilization and overall throughput
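
As a rough, back-of-the-envelope estimate (the ~2 TB/s figure below is an assumed HBM bandwidth for a modern datacenter GPU, not a number from this article), the KV Cache must be read once per generated token, so simply streaming it from memory already sets a floor on per-token latency:

bandwidth_gb_per_s = 2000   # assumed ~2 TB/s HBM bandwidth (illustrative)
for name, cache_gb in [("MHA", 10.7), ("GQA", 1.34), ("MQA", 0.17)]:
    ms_per_token = cache_gb / bandwidth_gb_per_s * 1000
    print(f"{name}: ~{ms_per_token:.2f} ms/token just to read the KV Cache")
# roughly: MHA ~5.4 ms, GQA ~0.7 ms, MQA ~0.1 ms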

Real-World Adoption

GQA has become the standard configuration in current mainstream large language models:

| Model | Query Heads | KV Heads | Group Ratio (h/g) | Attention Type |
| --- | --- | --- | --- | --- |
| LLaMA-2 7B | 32 | 32 | 1:1 | MHA |
| LLaMA-2 13B | 40 | 40 | 1:1 | MHA |
| LLaMA-2 70B | 64 | 8 | 8:1 | GQA |
| LLaMA-3 8B | 32 | 8 | 4:1 | GQA |
| LLaMA-3 70B | 64 | 8 | 8:1 | GQA |
| Mistral 7B | 32 | 8 | 4:1 | GQA |
| Gemini 1.0 Pro | — | 1 | — | MQA |

Notable trends:

  • LLaMA-2 series: Only the largest 70B model uses GQA, while the smaller 7B and 13B still use standard MHA. This indicates that at the time, the KV Cache bottleneck was primarily a concern for large models
  • LLaMA-3 series: All sizes (including 8B) adopt GQA, reflecting that GQA has been proven effective across all scales
  • Mistral 7B: Uses GQA (4:1) even at the 7B scale, combined with sliding window attention to further optimize inference efficiency
  • Gemini 1.0 Pro: Uses the more aggressive MQA approach, with all query heads sharing a single KV head
  • Industry consensus: GQA has become the default choice for new models, with 8 KV heads being a common configuration (a model's head counts can be read directly from its config, as shown below)
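
If a model's Hugging Face config is available, the attention type can be read off by comparing num_attention_heads with num_key_value_heads (the model ID below is just an example; downloading the config needs network access and possibly license acceptance):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
print(cfg.num_attention_heads, cfg.num_key_value_heads)   # expect 32 and 8, i.e. GQA with a 4:1 ratio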

PyTorch Implementation Notes

Implementing GQA requires only minor modifications to standard MHA:

# GQA: g KV heads, h query heads, each group of h//g queries shares one KV
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(self, H, h, g, d_k):
        super().__init__()
        self.h = h      # number of query heads
        self.g = g      # number of KV heads (number of groups)
        self.d_k = d_k  # per-head dimension (H is the model / hidden dimension)
        
        self.W_q = nn.Linear(H, h * d_k)    # h query heads
        self.W_k = nn.Linear(H, g * d_k)    # g KV heads
        self.W_v = nn.Linear(H, g * d_k)    # g KV heads
        self.W_o = nn.Linear(h * d_k, H)
    
    def forward(self, x):
        B, S, _ = x.shape
        
        Q = self.W_q(x).view(B, S, self.h, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, S, self.g, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, S, self.g, self.d_k).transpose(1, 2)
        
        # Key step: expand KV heads to match the number of query heads
        # Each KV head is repeated h//g times
        repeats = self.h // self.g
        K = K.repeat_interleave(repeats, dim=1)  # (B, h, S, d_k)
        V = V.repeat_interleave(repeats, dim=1)  # (B, h, S, d_k)
        
        # Standard attention computation
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        weights = F.softmax(scores, dim=-1)
        output = weights @ V
        
        output = output.transpose(1, 2).contiguous().view(B, S, -1)
        return self.W_o(output)

The key difference is that W_k and W_v have output dimension $g \times d_k$ (instead of $h \times d_k$), and then repeat_interleave is used to replicate each KV head $h/g$ times to match the number of query heads. Note that this replication does not increase the KV Cache size; only $g$ KV heads are actually cached.
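
A quick shape check of the module above (hypothetical sizes):

import torch

gqa = GroupedQueryAttention(H=512, h=8, g=2, d_k=64)   # 8 query heads in 2 groups
x = torch.randn(4, 16, 512)                            # (batch, seq_len, d_model)
print(gqa(x).shape)                                    # torch.Size([4, 16, 512])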

Summary

| Concept | Description |
| --- | --- |
| MHA bottleneck | KV Cache grows linearly with the number of heads, limiting inference efficiency and concurrency |
| MQA | All query heads share one KV pair; KV Cache reduced by a factor of $h$, but with noticeable quality loss |
| GQA | $h$ query heads divided into $g$ groups, each group sharing one KV head pair; a compromise between MHA and MQA |
| KV Cache reduction | MHA: $2hSd_k$ → GQA: $2gSd_k$ → MQA: $2Sd_k$ |
| Uptraining | Convert from an MHA checkpoint using ~5% of the original pretraining compute |
| Industry trend | GQA has become the standard configuration in LLaMA-3, Mistral, and other mainstream models |

Core intuition: A significant amount of information redundancy exists among KV heads in MHA. GQA allows a group of query heads to share the same KV head pair, reducing the KV Cache by several times with virtually no quality loss, thereby dramatically improving inference efficiency and serving concurrency. It is like a team meeting — not everyone needs to bring their own complete set of meeting materials; a few people can share one copy. What is saved is desk space (GPU memory), without affecting discussion quality (model capability).