One-sentence summary: MHA gives every head its own K and V (most expressive, most memory); MQA gives all heads one shared K/V (least memory, some quality loss); GQA splits the difference with group-sharing, and that is what most modern LLMs ship.


23.1 The Problem KV Cache Created

23.1.1 The Memory Math

Chapter 22 established that KV Cache is non-negotiable for production inference. But every head in Multi-Head Attention caches its own K and V, and that accumulates fast.

For a typical 7B-parameter model — 32 layers, 32 heads, head dimension 128, FP16:

KV Cache per request =
    32 layers × 32 heads × 2 (K and V) × seq_len × 128 × 2 bytes

At seq_len = 1024:

32 × 32 × 2 × 1024 × 128 × 2 = 536 MB

That is 536 MB for one user, one conversation, 1024 tokens. Scale to 100 concurrent users at 4096 tokens each:

536 MB × 4 (4096/1024) × 100 users ≈ 214 GB

Roughly 200 GB of KV Cache. That is more than two fully loaded A100s just for the cache. This is why the industry started asking whether all those independent K/V heads are actually necessary.
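That arithmetic is worth wrapping in a helper you can reuse for the rest of the chapter's numbers; a minimal sketch (the function name is illustrative):

```python
def kv_cache_bytes(n_layers, n_kv_heads, seq_len, head_dim, bytes_per_elem=2):
    """Per-request KV Cache size: K and V, per layer, per KV head, per token."""
    return n_layers * n_kv_heads * 2 * seq_len * head_dim * bytes_per_elem

# 7B model with full MHA (32 KV heads), 1024 tokens, FP16:
per_request = kv_cache_bytes(n_layers=32, n_kv_heads=32, seq_len=1024, head_dim=128)
print(f"{per_request / 1e6:.0f} MB")   # 537 MB

# 100 concurrent users at 4096 tokens each:
fleet = 100 * kv_cache_bytes(32, 32, 4096, 128)
print(f"{fleet / 1e9:.0f} GB")         # 215 GB
```

Swapping `n_kv_heads` is all it takes to reproduce the GQA and MQA numbers later in the chapter.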

23.1.2 The Root Tension

MHA was designed for training: every head learns a different projection, capturing different patterns. That is a feature. During inference, those independent projections become a burden — we must store one K/V set per head per layer per token.

Training wants expressiveness. Serving wants efficiency. MQA and GQA are the architectures born from that tension.

23.1.3 Three Mechanisms

Mechanism   Full name                 Core idea
MHA         Multi-Head Attention      Every head has independent K, V
MQA         Multi-Query Attention     All heads share one K, one V
GQA         Grouped-Query Attention   Groups of heads share one K/V each

23.2 MHA: The Baseline

23.2.1 Structure

MHA: each head has its own Q, K, V projections

In standard MHA with n_heads heads, every head has:

  • its own W_Q^{(i)} projection
  • its own W_K^{(i)} projection
  • its own W_V^{(i)} projection

KV Cache stores n_heads K tensors and n_heads V tensors per layer.

23.2.2 Why Multiple Heads Help

Different heads genuinely learn different things. In a sentence like "The agent tagged the reviewer because the PR was urgent":

  • Head 1 might track syntactic subject-verb: agent → tagged
  • Head 2 might track pronoun resolution: "the PR" ← which PR?
  • Head 3 might track causal reasoning: tagged → because → urgent
  • Head 4 might track recency, attending heavily to recent tokens

Independent K/V projections let each head build its own "perspective" on the token history. That is MHA's strength.

23.2.3 MHA Code Shape

import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads  = n_heads
        self.head_dim = d_model // n_heads

        # Full d_model projections (contain all heads concatenated)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)  # n_heads independent K projections
        self.W_v = nn.Linear(d_model, d_model)  # n_heads independent V projections
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.head_dim)
        k = self.W_k(x).view(B, T, self.n_heads, self.head_dim)
        v = self.W_v(x).view(B, T, self.n_heads, self.head_dim)
        # KV Cache stores n_heads sets of K and V

23.2.4 The Problem in Production

For a 32-head model, each layer needs 64 tensors in the KV Cache (32 K + 32 V). At long context or high concurrency, this becomes the binding constraint on how many requests you can serve.

An agent system doing tool-augmented reasoning at 16k context makes the problem concrete:

KV Cache per session (Llama-7B, 16k ctx, FP16):
    32 layers × 32 heads × 2 (K+V) × 16384 × 128 × 2 bytes ≈ 8 GB

8 GB per active session. On a GPU with 40 GB available (after loading the model weights), you can serve perhaps 4 concurrent long-context sessions. Scale that to a team, and you see the pressure to reduce KV Cache memory.


23.3 MQA: Collapse Everything

23.3.1 The Core Idea

Multi-Query Attention (Shazeer, 2019) makes a simple but aggressive choice: all query heads share one single K and one single V.

MQA: many Q heads, one shared K, one shared V
  • Q still has n_heads independent projections
  • K has 1 projection
  • V has 1 projection

KV Cache now stores 2 tensors per layer regardless of how many query heads exist.

23.3.2 Memory Savings

For the same 7B model at 1024 tokens:

MHA KV Cache = 32 layers × 32 heads × 2 × 1024 × 128 × 2 = 536 MB
MQA KV Cache = 32 layers ×  1 head  × 2 × 1024 × 128 × 2 = 16.75 MB

97% reduction. The same GPU that serves 5 concurrent users with MHA can now serve ~160 users with MQA.

23.3.3 MQA Code Shape

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads  = n_heads
        self.head_dim = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model)        # full n_heads
        self.W_k = nn.Linear(d_model, self.head_dim)  # only 1 head!
        self.W_v = nn.Linear(d_model, self.head_dim)  # only 1 head!
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.head_dim)
        k = self.W_k(x).view(B, T, 1, self.head_dim)
        v = self.W_v(x).view(B, T, 1, self.head_dim)
        # k, v broadcast from shape [B,T,1,head_dim] to [B,T,n_heads,head_dim]
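The broadcast mentioned in that last comment is easy to verify with toy dimensions:

```python
import torch

B, T, n_heads, head_dim = 2, 4, 8, 16
q = torch.randn(B, T, n_heads, head_dim)
k = torch.randn(B, T, 1, head_dim)   # single shared K head

# Move heads to dim 1 for attention: [B, heads, T, head_dim]
q_t = q.transpose(1, 2)              # [2, 8, 4, 16]
k_t = k.transpose(1, 2)              # [2, 1, 4, 16]

# matmul broadcasts k_t's singleton head dim across all 8 query heads
scores = q_t @ k_t.transpose(-2, -1)
print(scores.shape)  # torch.Size([2, 8, 4, 4])
```

No copy of K is materialized; PyTorch's batched matmul broadcasting handles the sharing for free.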

23.3.4 The Cost

Forcing all query heads to consult the same K/V reference material limits each head's ability to build an independent view of the token history. MQA works well for many tasks but shows measurable quality degradation on tasks requiring diverse long-range pattern capture. Google's PaLM adopted MQA; the broader community found the quality loss hard to accept at frontier scale.


23.4 GQA: The Practical Middle Ground

23.4.1 Core Idea

Grouped-Query Attention (Ainslie et al., 2023) introduces one hyperparameter: n_kv_heads, the number of K/V groups.

Query heads are divided into n_kv_heads groups. All query heads within a group share one K projection and one V projection.

GQA: Q heads split into groups, each group shares one K and V

Formally:

  • n_heads — number of Q heads
  • n_kv_heads — number of KV groups
  • n_rep = n_heads / n_kv_heads — Q heads per group

Special cases:

  • n_kv_heads = n_heads → MHA (every head independent)
  • n_kv_heads = 1 → MQA (all heads share)
  • 1 < n_kv_heads < n_heads → GQA

23.4.2 GQA Code Shape

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        self.n_heads    = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep      = n_heads // n_kv_heads
        self.head_dim   = d_model // n_heads

        self.W_q = nn.Linear(d_model, n_heads    * self.head_dim)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        q = self.W_q(x).view(B, T, self.n_heads,    self.head_dim)
        k = self.W_k(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.W_v(x).view(B, T, self.n_kv_heads, self.head_dim)

        # Expand K and V to match Q head count
        k = self.repeat_kv(k)  # [B, T, n_heads, head_dim]
        v = self.repeat_kv(v)
        # Attention proceeds identically to MHA from here

    def repeat_kv(self, x):
        """Repeat each KV group n_rep times to match Q head count."""
        B, T, n_kv, head_dim = x.shape
        if self.n_rep == 1:
            return x
        # [B, T, n_kv, head_dim] -> [B, T, n_kv, n_rep, head_dim]
        x = x.unsqueeze(3).expand(B, T, n_kv, self.n_rep, head_dim)
        # -> [B, T, n_heads, head_dim]
        return x.reshape(B, T, self.n_heads, head_dim)
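To confirm the shapes compose end to end, here is a compact, runnable variant of the class above with the attention computation filled in (no KV cache or causal mask; `repeat_interleave` plays the role of `repeat_kv`):

```python
import math
import torch
import torch.nn as nn

class TinyGQA(nn.Module):
    """Minimal runnable version of the GQA shape logic (no cache, no mask)."""
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.n_rep = n_heads // n_kv_heads
        self.head_dim = d_model // n_heads
        self.W_q = nn.Linear(d_model, n_heads * self.head_dim)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, _ = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # repeat_kv: each KV group is repeated n_rep times along the head axis,
        # so heads 0..n_rep-1 see group 0, the next n_rep heads see group 1, etc.
        k = k.repeat_interleave(self.n_rep, dim=1)   # [B, n_heads, T, head_dim]
        v = v.repeat_interleave(self.n_rep, dim=1)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        out = att.softmax(dim=-1) @ v                # [B, n_heads, T, head_dim]
        return self.W_o(out.transpose(1, 2).reshape(B, T, -1))

out = TinyGQA(d_model=64, n_heads=8, n_kv_heads=2)(torch.randn(2, 5, 64))
print(out.shape)  # torch.Size([2, 5, 64])
```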

23.4.2b Training vs Inference

One nuance worth understanding: during training, KV Cache is not used (the whole sequence is processed in parallel with causal masking). So GQA's benefit during training is just the reduced parameter count from smaller K and V projection matrices — small but real.

During inference, the benefit is much larger. Decode is memory-bandwidth-bound (Chapter 22, Section 22.6.2): each generated token reads the entire KV Cache from HBM. Smaller KV Cache means more of it fits in SRAM, fewer HBM reads per token, and higher throughput. GQA's 4× or 8× memory reduction translates almost directly into faster decode.

23.4.3 Geometric View of repeat_kv

For n_heads = 8, n_kv_heads = 2:

Original K/V shape: [B, T, 2, head_dim]

  KV group 0          KV group 1

After repeat_kv: [B, T, 8, head_dim]

  Q0   Q1   Q2   Q3   Q4   Q5   Q6   Q7
   ↓    ↓    ↓    ↓    ↓    ↓    ↓    ↓
  KV0  KV0  KV0  KV0  KV1  KV1  KV1  KV1

Q heads 0–3 share KV group 0; Q heads 4–7 share KV group 1. Computationally this is a tensor repeat, not a separate projection — no parameters are added.
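The diagram's assignment is just integer division of the query-head index by n_rep:

```python
n_heads, n_kv_heads = 8, 2
n_rep = n_heads // n_kv_heads          # 4 query heads per KV group
groups = [q_head // n_rep for q_head in range(n_heads)]
print(groups)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

This is the same `group_idx = q_head_idx // n_rep` mapping a GQA-aware kernel uses to pick the right K/V tile without expanding the tensors.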


23.5 Three-Way Comparison

23.5.1 Memory Numbers

For a 7B model, 1024-token sequence, FP16:

KV Cache size comparison: MHA vs GQA vs MQA

MHA (32 KV heads):

32 × 32 × 2 × 1024 × 128 × 2 bytes = 536 MB

GQA (8 KV heads):

32 ×  8 × 2 × 1024 × 128 × 2 bytes = 134 MB

MQA (1 KV head):

32 ×  1 × 2 × 1024 × 128 × 2 bytes = 16.75 MB

Mechanism   KV heads   KV Cache   vs MHA
MHA         32         536 MB     100%
GQA         8          134 MB     25%
MQA         1          16.75 MB   3.1%

GQA hits 25% of MHA's memory cost while retaining close to MHA-level quality. MQA gets to 3.1% but pays a steeper quality price.

23.5.2 Quality vs Efficiency

GQA benchmark: inference time vs model quality across MHA, GQA, MQA

From the GQA paper's benchmarks:

  • GQA-G8 (8 groups) sits close to MHA in quality
  • GQA-G8 inference time is close to MQA
  • Quality improvements from adding more KV groups plateau quickly beyond 8

An important empirical finding: in trained MHA models, different heads' K and V representations are often surprisingly similar. Many heads learn near-redundant projections. That is why sharing K/V within a group loses relatively little — the diversity you give up was not providing much signal to begin with.

This finding has an architectural implication: if you are designing a model from scratch rather than converting an existing MHA checkpoint, you can train directly with GQA and the model learns to use its KV capacity efficiently from the start. The redundancy in MHA is partly an artifact of having no incentive to differentiate K/V representations across heads during training.

23.5.3 Serving Concurrency Impact

The memory numbers above directly determine how many users you can serve simultaneously. On an A100 80GB GPU with a 7B model loaded at FP16 (14 GB), roughly 66 GB remains for KV Cache:

Mechanism       Per-session KV at 4k ctx   Max concurrent sessions
MHA             536 MB × 4 = 2.1 GB        ~31
GQA (8 heads)   134 MB × 4 = 536 MB        ~123
MQA             16.75 MB × 4 = 67 MB       ~984

GQA roughly quadruples your concurrent user capacity versus MHA at the same hardware budget. That is the business case in one table.
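The session counts follow from dividing the free-memory budget by the per-session cache size; a quick sketch (small rounding differences from the table are expected):

```python
def max_sessions(kv_gb_per_session, free_gb=66):
    """Concurrent sessions that fit in the KV Cache budget (A100 80GB, 7B FP16)."""
    return int(free_gb / kv_gb_per_session)

mha = max_sessions(0.536 * 4)      # 2.144 GB per 4k-context session
gqa = max_sessions(0.134 * 4)
mqa = max_sessions(0.01675 * 4)
print(mha, gqa, mqa)
```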

23.5.4 The Full Tradeoff Table

Mechanism   Quality     Inference speed   KV memory   When to use
MHA         Highest     Slowest           Largest     Research, small models, training-only settings
MQA         Some loss   Fastest           Smallest    Edge/mobile, extreme throughput requirements
GQA         Near-MHA    Near-MQA          Medium      Almost everything in production

23.6 What Modern Models Ship

23.6.1 The Industry Has Converged on GQA

Production models using GQA: Llama-3, Mistral, Qwen
Model          Params   Q heads   KV heads   Group size
Llama-2 7B     7B       32        32         1 (MHA)
Llama-2 70B    70B      64        8          8
Llama-3 8B     8B       32        8          4
Llama-3 70B    70B      64        8          8
Mistral 7B     7B       32        8          4
Qwen-1.5 7B    7B       32        32         1 (MHA)
Qwen-1.5 32B   32B      40        8          5
Qwen-2 7B      7B       28        4          7

A few observations: smaller models sometimes stick to MHA because the absolute memory cost is manageable and the full expressiveness matters; larger models almost universally go GQA; 8 KV heads is a common sweet spot.

23.6.2 Why 8 KV Heads?

Research shows quality improves rapidly from 1 to 8 KV heads and then flattens. Meanwhile, 8 divides evenly into common tensor-parallel configurations (2, 4, or 8 GPUs), so the KV heads can be distributed cleanly across devices. It is both empirically good and operationally convenient.

23.6.3 Multi-GPU Benefit

In tensor-parallel serving (e.g., 4 GPUs):

MHA with 32 heads:

GPU 0: Q heads 0–7,  K heads 0–7,  V heads 0–7
GPU 1: Q heads 8–15, K heads 8–15, V heads 8–15
...

GQA with 32 Q heads, 8 KV heads:

GPU 0: Q heads 0–7,  K heads 0–1, V heads 0–1
GPU 1: Q heads 8–15, K heads 2–3, V heads 2–3
...

Each GPU's KV Cache is 4× smaller. This matters when running many concurrent requests.
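The head-to-GPU assignment above can be computed mechanically; a small sketch (the dict layout is illustrative, not a real serving API):

```python
n_gpus, n_q_heads, n_kv_heads = 4, 32, 8
q_per, kv_per = n_q_heads // n_gpus, n_kv_heads // n_gpus

# gpu -> (Q head indices, KV head indices)
placement = {
    g: (list(range(g * q_per, (g + 1) * q_per)),
        list(range(g * kv_per, (g + 1) * kv_per)))
    for g in range(n_gpus)
}
print(placement[0])  # ([0, 1, 2, 3, 4, 5, 6, 7], [0, 1])
print(placement[1])  # ([8, 9, 10, 11, 12, 13, 14, 15], [2, 3])
```

Note that this clean split requires n_kv_heads to be divisible by the GPU count, which is part of why 8 KV heads is such a convenient choice.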


23.7 Converting MHA Checkpoints to GQA

If you already have a trained MHA model, Google's GQA paper proposed uptraining:

  1. Average weights — for each group of K/V heads to be merged, take the mean of their projection matrices
  2. Continue training — run a short fine-tuning pass on about 5% of the original training data
  3. Recover quality — the model adapts quickly because the averaged weights are already a reasonable initialization

def convert_mha_to_gqa(k_weights, n_heads, n_kv_heads, head_dim, d_model):
    """Average groups of K (or V) projection matrices."""
    group_size = n_heads // n_kv_heads
    # k_weights shape: [d_model, n_heads * head_dim]
    k = k_weights.reshape(d_model, n_heads, head_dim)
    # Group and average: [d_model, n_kv_heads, group_size, head_dim]
    k_grouped = k.reshape(d_model, n_kv_heads, group_size, head_dim)
    k_gqa = k_grouped.mean(dim=2)  # [d_model, n_kv_heads, head_dim]
    return k_gqa.reshape(d_model, n_kv_heads * head_dim)

This works because of the empirical redundancy mentioned earlier: the averaged result is already a reasonable approximation of what each group's shared K/V should look like.

The GQA paper reports recovering most of the quality gap with just 5% of the original training data. This makes uptraining practical: you invest in training a high-quality MHA model once, then cheaply convert it to a GQA model for efficient serving.
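A quick sanity check of the averaging step (the function is restated in compact form so the snippet runs standalone):

```python
import torch

def convert_mha_to_gqa(k_weights, n_heads, n_kv_heads, head_dim, d_model):
    """Average groups of K (or V) projection matrices (compact restatement)."""
    group_size = n_heads // n_kv_heads
    k = k_weights.reshape(d_model, n_kv_heads, group_size, head_dim)
    return k.mean(dim=2).reshape(d_model, n_kv_heads * head_dim)

d_model, n_heads, n_kv_heads, head_dim = 64, 8, 2, 8
w = torch.randn(d_model, n_heads * head_dim)
w_gqa = convert_mha_to_gqa(w, n_heads, n_kv_heads, head_dim, d_model)
print(w_gqa.shape)  # torch.Size([64, 16])

# If every head in a group already had identical weights, averaging is a no-op --
# the redundancy argument in the limit:
w_head = torch.randn(d_model, head_dim)
w_dup = w_head.repeat(1, n_heads)   # 8 identical heads
converted = convert_mha_to_gqa(w_dup, n_heads, n_kv_heads, head_dim, d_model)
assert torch.allclose(converted, w_head.repeat(1, n_kv_heads))
```

The closer the real heads are to this identical-heads limit, the less information the averaging destroys, which is exactly what the uptraining pass then recovers.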


23.8 Flash Attention and GQA Together

Flash Attention 2 added native GQA support directly in the kernel. This matters for efficiency.

Without GQA awareness, a Flash Attention implementation would need to expand K and V via repeat_kv before the tiled loop — creating a larger tensor in HBM. With native GQA, the kernel maps each Q block to its KV group index (group_idx = q_head_idx // n_rep) and loads the right K/V tile without materializing the expanded tensor.

The effect: you get both Flash Attention's IO efficiency and GQA's smaller KV footprint, with no extra memory overhead from the repeat operation. When you call F.scaled_dot_product_attention or use a GQA-aware implementation like vLLM or TensorRT-LLM, this optimization is typically applied automatically.
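A sketch of the two call patterns with `F.scaled_dot_product_attention` (the `enable_gqa` flag requires PyTorch 2.5 or newer, so it is shown commented out):

```python
import torch
import torch.nn.functional as F

B, T, n_heads, n_kv_heads, d = 2, 16, 8, 2, 32
q = torch.randn(B, n_heads, T, d)
k = torch.randn(B, n_kv_heads, T, d)
v = torch.randn(B, n_kv_heads, T, d)

# Portable path: expand K/V to the Q head count before the fused kernel.
out = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(n_heads // n_kv_heads, dim=1),
    v.repeat_interleave(n_heads // n_kv_heads, dim=1),
    is_causal=True,
)
print(out.shape)  # torch.Size([2, 8, 16, 32])

# On PyTorch 2.5+, the kernel can take k, v at n_kv_heads directly,
# avoiding the expanded tensors (version-dependent):
# out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
```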


23.10 Common Misconceptions

"GQA is just MQA with more heads." Not exactly. GQA is a parameterized family. MHA and MQA are the two extremes; GQA is the whole spectrum between them. The key design choice is n_kv_heads.

"Fewer KV heads is always better." The quality-efficiency tradeoff is real. Going from 32 to 8 KV heads cuts memory 4× with minimal quality loss; going from 8 to 1 cuts memory another 8× but with more visible degradation. The optimal setting depends on your serving constraints and quality bar.

"GQA only affects inference." GQA also reduces parameter count slightly (smaller K and V projection matrices), which can speed up training and reduce model file size. The effect is small but not zero. For a 7B model with 32 heads going to 8 KV heads: the K and V projection matrices shrink from [d_model, d_model] to [d_model, d_model/4], saving about 6% of total parameters.

"Every model should use GQA." For very small models (sub-7B) deployed in memory-abundant settings, MHA's extra expressiveness may be worth the overhead. Always measure quality before committing to a specific n_kv_heads.

"Flash Attention and GQA conflict." They are complementary. Flash Attention 2 added native GQA support: it handles the repeat_kv internally during the tiled computation, so you get both the IO-efficient kernel and the smaller KV footprint simultaneously.


23.11 Chapter Summary

MHA (Multi-Head Attention)
  n_kv_heads = n_heads
  Each head has independent Q, K, V
  KV Cache: 2 × n_heads tensors per layer
  Best quality, highest memory

MQA (Multi-Query Attention)
  n_kv_heads = 1
  All heads share one K, one V
  KV Cache: 2 tensors per layer
  Lowest memory, quality risk at scale

GQA (Grouped-Query Attention)
  1 < n_kv_heads < n_heads
  Groups of heads share K/V
  KV Cache: 2 × n_kv_heads tensors per layer
  Near-MHA quality, near-MQA efficiency

Selection Guide

Situation                            Recommendation        Reason
Research / training focus            MHA                   Maximum expressiveness
Large-scale production serving       GQA (8 KV heads)      Best quality-efficiency balance
Edge / mobile / extreme efficiency   MQA                   Minimum memory footprint
Uncertain                            GQA with 8 KV heads   Safe, empirically validated default

23.11.1 Reading a Model Config

The Hugging Face config.json for any modern model will list both fields. For Llama-3 8B:

{
  "num_attention_heads": 32,
  "num_key_value_heads": 8
}

num_attention_heads is n_heads (Q heads). num_key_value_heads is n_kv_heads. The group size is 32 / 8 = 4 — each KV pair is shared by 4 query heads. If num_key_value_heads equals num_attention_heads, the model uses MHA. If it equals 1, it uses MQA.
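Reading those two fields from a config takes only a few lines (the classification logic here is a simple illustration):

```python
import json

# A Llama-3-8B-style config fragment; real configs have many more fields.
config = json.loads('{"num_attention_heads": 32, "num_key_value_heads": 8}')
n_heads = config["num_attention_heads"]
n_kv    = config.get("num_key_value_heads", n_heads)  # absent usually means MHA

group_size = n_heads // n_kv
kind = ("MHA" if n_kv == n_heads
        else "MQA" if n_kv == 1
        else f"GQA, group size {group_size}")
print(kind)  # GQA, group size 4
```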


Chapter Checklist

After this chapter, you should be able to:

  • Explain why MHA's KV Cache becomes a bottleneck at long context or high concurrency.
  • Describe MHA, MQA, and GQA in one sentence each.
  • Calculate KV Cache size for each mechanism given model dimensions.
  • Explain why GQA quality loss is small despite large memory savings.
  • Read a model config and identify its n_heads and n_kv_heads.
  • Implement repeat_kv.
  • Estimate how many concurrent users a given GPU can serve under each mechanism.

Further Reading

  1. Fast Transformer Decoding: One Write-Head is All You Need (Shazeer, 2019) — MQA original paper
  2. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (Ainslie et al., 2023)
  3. Llama 2: Open Foundation and Fine-Tuned Chat Models (Meta, 2023) — shows MHA→GQA transition across model scales

See You in the Next Chapter

GQA reduces the amount of K/V data we store. But even with a smaller cache, every token still attends to every other token within the window — the quadratic cost of full Attention remains.

Chapter 24 explores what happens when you drop that requirement entirely. Sparse Attention lets each token attend to only a chosen subset of the sequence, pushing complexity toward O(N). And Infini Attention goes further, using a fixed-size compressed memory to handle context that grows without bound.

Cite this page
Zhang, Wayland (2026). Chapter 23: From MHA to MQA to GQA. In Transformer Architecture: From Intuition to Implementation. https://waylandz.com/llm-transformer-book-en/chapter-23-mha-mqa-gqa
@incollection{zhang2026transformer_chapter_23_mha_mqa_gqa,
  author = {Zhang, Wayland},
  title = {Chapter 23: From MHA to MQA to GQA},
  booktitle = {Transformer Architecture: From Intuition to Implementation},
  year = {2026},
  url = {https://waylandz.com/llm-transformer-book-en/chapter-23-mha-mqa-gqa}
}