一句话总结:MHA 让每个注意力头都有独立的 K、V,表达能力最强但内存开销巨大;MQA 让所有头共享同一组 K、V,内存最省但质量下降;GQA 是折中方案,分组共享 K、V,在内存效率和模型质量之间找到了最佳平衡点,已成为 Llama-3、Mistral、Qwen 等主流模型的标配。


23.1 为什么需要改进 MHA?

23.1.1 KV Cache 的内存困境

在第 21 章和第 22 章,我们深入讨论了 KV Cache。它的核心思想是:推理时缓存已计算的 K、V,避免重复计算

但这带来一个严重的问题:内存爆炸

让我们算一笔账。对于一个典型的 7B 参数模型:

KV Cache 大小 = 层数 × KV头数 × 2(K和V) × 序列长度 × 每个头维度 × 精度

假设:

  • 32 层
  • 32 个注意力头
  • 序列长度 1024
  • 每个头维度 128
  • 使用 FP16(2 bytes)
32 × 32 × 2 × 1024 × 128 × 2 bytes = 536 MB

536 MB,这只是一个用户的一次对话!

如果你要同时服务 100 个用户,每个用户对话 4096 tokens?

536 MB × 4(4096/1024)× 100 用户  200 GB

这就是 ~200 GB 的 KV Cache!这就是为什么 KV Cache 优化如此重要。

23.1.2 问题的本质

MHA(Multi-Head Attention)的设计理念是:让每个头都能学习不同的表示

这意味着每个头都有独立的:

  • Query 投影矩阵 WQW_Q
  • Key 投影矩阵 WKW_K
  • Value 投影矩阵 WVW_V

在训练时,这没问题——计算完就丢掉。

但在推理时,为了实现 KV Cache,我们必须保存每一层、每个头的 K 和 V。

核心矛盾:MHA 的设计初衷是追求表达能力,但在部署阶段,这种"奢侈"的设计导致了严重的内存瓶颈。

23.1.3 本章要解决的问题

我们将探讨三种注意力机制:

机制全称核心思想
MHAMulti-Head Attention每个头独立的 K、V
MQAMulti-Query Attention所有头共享一组 K、V
GQAGrouped-Query Attention分组共享 K、V

23.2 MHA:标准多头注意力回顾

23.2.1 MHA 的结构

MHA 是原始 Transformer 论文《Attention Is All You Need》提出的标准结构。

MHA 结构图

如图所示,在 MHA 中:

  • 每个头都有独立的 Q、K、V
  • 如果有 n_heads 个头,就有 n_heads 套 K 和 n_heads 套 V
  • KV Cache 需要存储 n_heads * 2 个张量(每层)

23.2.2 MHA 的优点

表达能力强:每个头可以学习完全不同的 attention pattern。

比如,在处理句子 "The cat sat on the mat because it was tired" 时:

  • Head 1 可能专注于主语-谓语关系(cat -> sat)
  • Head 2 可能专注于代词指代(it -> cat)
  • Head 3 可能专注于空间关系(sat -> on -> mat)
  • Head 4 可能专注于因果关系(sat -> because -> tired)

不同的 K、V 投影让每个头可以构建完全不同的"关注视角"。

23.2.3 MHA 的问题

内存开销巨大,尤其在推理阶段。

我们再看一次 KV Cache 的公式:

# MHA  KV Cache 大小(每层)
kv_cache_size = n_heads * 2 * seq_len * head_dim * bytes_per_element

对于 32 个头的模型,每层要存储 64 个张量(32 个 K + 32 个 V)。

这在以下场景会成为严重瓶颈:

  • 长上下文:处理 100K tokens 的文档
  • 高并发:同时服务数百个用户
  • 端侧部署:手机/边缘设备的内存有限

23.2.4 MHA 代码示例

# 标准 MHA:每个头独立的 K、V
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # 每个头都有独立的 Q、K、V 投影
        self.W_q = nn.Linear(d_model, d_model)  # 包含所有头的 Q
        self.W_k = nn.Linear(d_model, d_model)  # 包含所有头的 K
        self.W_v = nn.Linear(d_model, d_model)  # 包含所有头的 V
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape

        # Q: [B, T, n_heads, head_dim]
        q = self.W_q(x).view(B, T, self.n_heads, self.head_dim)
        # K: [B, T, n_heads, head_dim] - 每个头独立的 K
        k = self.W_k(x).view(B, T, self.n_heads, self.head_dim)
        # V: [B, T, n_heads, head_dim] - 每个头独立的 V
        v = self.W_v(x).view(B, T, self.n_heads, self.head_dim)

        # KV Cache 需要存储 n_heads  K  V
        # ...

23.3 MQA:多查询注意力

23.3.1 MQA 的核心思想

2019 年,Google 在论文《Fast Transformer Decoding: One Write-Head is All You Need》中提出了 MQA。

核心思想极其简单:所有的 Query 头共享同一组 Key 和 Value

MQA 结构图

如图所示:

  • Q 仍然有多个头(图中 Head 1-4 的绿色方块)
  • 但 K 和 V 只有一个(共享)
  • KV Cache 只需要存储 n_heads + 1 个张量(n 个 Q + 1 个 K + 1 个 V,实际是 2 个张量因为 Q 不需要缓存)

等等,图上写的是 "n_heads + 1",这里需要澄清:

实际上 KV Cache 只存储 K 和 V,不存储 Q(Q 每次都要重新计算)。所以:

  • MHA 的 KV Cache:n_heads 个 K + n_heads 个 V = 2 * n_heads
  • MQA 的 KV Cache:1 个 K + 1 个 V = 2

23.3.2 MQA 的直觉理解

想象你是一位老师(模型),面对一个教室的学生(tokens)。

MHA 是这样的:

  • 你有 8 个分身(8 个头)
  • 每个分身都带着自己的笔记本(独立的 K、V)
  • 每个分身用自己的方式记录每个学生的特点

MQA 是这样的:

  • 你有 8 个分身(8 个头)
  • 但所有分身共用一本笔记本(共享的 K、V)
  • 每个分身可以提不同的问题(独立的 Q)
  • 但他们查阅的是同一份资料

23.3.3 MQA 的优势

内存节省惊人

回到我们之前的 7B 模型例子:

MHA KV Cache = 32  × 32  × 2 × 1024 × 128 × 2 = 536 MB
MQA KV Cache = 32  × 1  × 2 × 1024 × 128 × 2 = 16.75 MB

从 536 MB 降到 16.75 MB,减少了 97%!

这意味着:

  • 同样的 GPU 内存可以支持 32 倍的并发用户
  • 或者支持 32 倍长的上下文
  • 或者在内存受限的设备上运行

23.3.4 MQA 的代价

天下没有免费的午餐。MQA 的代价是模型质量下降

为什么?因为所有头被迫使用相同的 K、V。

回到我们的例子 "The cat sat on the mat because it was tired":

  • Head 1 想关注 cat -> sat 关系
  • Head 2 想关注 it -> cat 关系
  • Head 3 想关注 sat -> mat 关系

但在 MQA 中,它们必须查阅同一份 K、V"参考资料"。这限制了每个头捕捉不同模式的能力。

23.3.5 MQA 代码示例

# MQA:所有头共享 K、V
class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # Q 仍然是完整的 d_model 维度(包含所有头)
        self.W_q = nn.Linear(d_model, d_model)
        # K、V 只有一个头的维度!
        self.W_k = nn.Linear(d_model, self.head_dim)  # 只有 head_dim
        self.W_v = nn.Linear(d_model, self.head_dim)  # 只有 head_dim
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape

        # Q: [B, T, n_heads, head_dim] - 每个头独立
        q = self.W_q(x).view(B, T, self.n_heads, self.head_dim)

        # K, V: [B, T, 1, head_dim] - 只有一份!
        k = self.W_k(x).view(B, T, 1, self.head_dim)
        v = self.W_v(x).view(B, T, 1, self.head_dim)

        # 计算 attention 时,K、V 会广播到所有头
        # k: [B, T, 1, head_dim] -> broadcast to [B, T, n_heads, head_dim]
        # ...

23.3.6 MQA 的实际应用

Google 的 PaLM(Pathways Language Model)系列采用了 MQA。

但业界发现:对于追求最高质量的模型,MQA 的质量损失难以接受

于是,GQA 应运而生。


23.4 GQA:分组查询注意力

23.4.1 GQA 的核心思想

GQA(Grouped-Query Attention)是 Google 在 2023 年论文《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》中提出的。

核心思想:把 Q 头分成若干组,每组共享一套 K、V

GQA 结构图

如图所示:

  • Q 仍然有多个头(Head 1-4)
  • 但 K、V 不再是每个头独立,也不是全部共享
  • 而是分组共享:Head 1-2 共享一组 K、V,Head 3-4 共享另一组
  • KV Cache 需要存储 2 * n_heads / g 个张量(g 是组数)

23.4.2 GQA 的数学表达

设:

  • n_heads = Q 头的数量
  • n_kv_heads = KV 组的数量
  • g = n_heads / n_kv_heads(每组包含的 Q 头数量)

则:

  • MHA:n_kv_heads = n_heads(每个 Q 头一组 KV)
  • MQA:n_kv_heads = 1(所有 Q 头共享一组 KV)
  • GQA:1 < n_kv_heads < n_heads(折中方案)

常见配置:

  • Llama-2 70B:n_heads=64, n_kv_heads=8(每 8 个 Q 头共享一组 KV)
  • Llama-3 8B:n_heads=32, n_kv_heads=8(每 4 个 Q 头共享一组 KV)
  • Mistral 7B:n_heads=32, n_kv_heads=8(每 4 个 Q 头共享一组 KV)

23.4.3 GQA 的直觉理解

继续用老师和教室的比喻:

GQA 是这样的:

  • 你有 8 个分身(8 个头)
  • 分身被分成 2 组,每组 4 人
  • 每组共用一本笔记本(组内共享 K、V)
  • 组与组之间的笔记本不同

这样:

  • 同一组内的分身可以协作(共享信息)
  • 不同组之间仍有差异(不同的参考资料)
  • 内存开销大大减少(只需要 2 本笔记本,而不是 8 本)

23.4.4 三种机制对比

三种机制对比

上图完美展示了三种机制的对比:

架构对比(上半部分):

  • Multi-head:8 个 Q 头配 8 个 K 头配 8 个 V 头
  • Grouped-query:8 个 Q 头配 2 个 K 头配 2 个 V 头
  • Multi-query:8 个 Q 头配 1 个 K 头配 1 个 V 头

性能对比(左下图):

  • 横轴是推理时间(Time per sample)
  • 纵轴是模型性能(Performance)
  • GQA-XXL 在性能上接近 MHA-XXL
  • 但推理时间接近 MQA-XXL

组数影响(右下图):

  • 横轴是 GQA 的组数
  • 组数越多(接近 MHA),推理越慢
  • 组数越少(接近 MQA),推理越快
  • 8 组是一个甜点

GQA 的核心价值:用接近 MQA 的推理效率,获得接近 MHA 的模型质量。

23.4.5 KV Cache 内存对比

KV Cache 大小对比

让我们用具体数字来说话。对于一个 7B 模型,1024 tokens 序列:

MHA(32 个 KV 头):

32  × 32 KV头 × 128 head_dim × 2(K和V) × 2(FP16) × 1024 tokens
= 536 MB

GQA(8 个 KV 头):

32  × 8 KV头 × 128 head_dim × 2(K和V) × 2(FP16) × 1024 tokens
= 134 MB

MQA(1 个 KV 头):

32  × 1 KV头 × 128 head_dim × 2(K和V) × 2(FP16) × 1024 tokens
= 16.75 MB
机制KV 头数KV Cache 大小相对 MHA
MHA32536 MB100%
GQA8134 MB25%
MQA116.75 MB3.1%

GQA 将内存降低了 75%,同时保持了接近 MHA 的质量!

23.4.6 GQA 代码示例

# GQA:分组共享 K、V
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        self.n_heads = n_heads        # Q 头数量
        self.n_kv_heads = n_kv_heads  # KV 组数量
        self.n_rep = n_heads // n_kv_heads  # 每组包含的 Q 头数
        self.head_dim = d_model // n_heads

        # Q 有完整的 n_heads 个头
        self.W_q = nn.Linear(d_model, n_heads * self.head_dim)
        # K、V 只有 n_kv_heads 个头
        self.W_k = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape

        # Q: [B, T, n_heads, head_dim]
        q = self.W_q(x).view(B, T, self.n_heads, self.head_dim)

        # K, V: [B, T, n_kv_heads, head_dim]
        k = self.W_k(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.W_v(x).view(B, T, self.n_kv_heads, self.head_dim)

        # 关键:将 K、V 扩展以匹配 Q 的头数
        # 方法是重复 n_rep 
        k = self.repeat_kv(k)  # [B, T, n_heads, head_dim]
        v = self.repeat_kv(v)  # [B, T, n_heads, head_dim]

        # 后面的 attention 计算与标准 MHA 相同
        # ...

    def repeat_kv(self, x):
        """将 KV 头重复以匹配 Q 头数量"""
        B, T, n_kv_heads, head_dim = x.shape
        if self.n_rep == 1:
            return x
        # [B, T, n_kv_heads, head_dim] -> [B, T, n_kv_heads, n_rep, head_dim]
        x = x.unsqueeze(3).expand(B, T, n_kv_heads, self.n_rep, head_dim)
        # -> [B, T, n_heads, head_dim]
        return x.reshape(B, T, self.n_heads, head_dim)

23.4.7 repeat_kv 的几何理解

repeat_kv 函数是 GQA 的核心操作。让我们用一个具体例子理解它:

假设 n_heads=8, n_kv_heads=2(每 4 个 Q 头共享一组 KV):

原始 K/V 形状: [B, T, 2, head_dim]
           
     K/V  0    K/V  1
                  

扩展后形状: [B, T, 8, head_dim]
           
   Q头0  Q头1  Q头2  Q头3  Q头4  Q头5  Q头6  Q头7
                                 
   KV0   KV0  KV0  KV0   KV1  KV1  KV1  KV1

Q 头 0-3 使用 KV 组 0,Q 头 4-7 使用 KV 组 1。


23.5 三种机制的性能权衡

23.5.1 质量 vs 效率的帕累托前沿

在机器学习中,我们经常面临质量和效率的权衡。理想情况是找到帕累托最优点——在不牺牲一方的前提下无法改善另一方。

机制模型质量推理速度内存占用适用场景
MHA最高最慢最大训练、研究、不计成本的部署
MQA有损失最快最小端侧部署、极致效率需求
GQA接近 MHA接近 MQA较小生产部署、平衡需求

23.5.2 为什么 GQA 质量损失小?

一个有趣的研究发现:在 MHA 中,不同头的 K、V 表示存在高度相似性

换句话说,虽然每个头理论上可以学习完全不同的表示,但实际上:

  • 很多头学到了相似的模式
  • K、V 的多样性不如 Q 那么重要
  • 共享 K、V 并不会导致严重的信息损失

这就是为什么 GQA 可以在大幅减少 KV 头的同时,保持接近 MHA 的质量。

23.5.3 训练 vs 推理的考量

有一个关键点需要理解:

训练阶段

  • 不使用 KV Cache(因为要计算所有位置)
  • MHA、MQA、GQA 的计算成本差异不大
  • 参数量:GQA < MHA(K、V 投影矩阵更小)

推理阶段

  • 必须使用 KV Cache(否则效率太低)
  • MHA 的内存瓶颈凸显
  • GQA 的优势充分发挥

这就是为什么 GQA 成为了部署友好型架构的首选。


23.6 主流模型的选择

23.6.1 GQA 已成为主流

使用 GQA 的主流模型

如图所示,GQA 已经成为现代大模型的主流选择:

  • Llama-3(Meta)
  • Mistral-7B(Mistral AI)
  • Qwen-1.5-32B / Qwen-2(阿里)
  • 以及更多...

23.6.2 各模型的具体配置

模型参数量Q 头数KV 头数组大小
Llama-2 7B7B32321 (MHA)
Llama-2 70B70B6488
Llama-3 8B8B3284
Llama-3 70B70B6488
Mistral 7B7B3284
Qwen-1.5 7B7B32321 (MHA)
Qwen-1.5 32B32B4085
Qwen-2 7B7B2847

有趣的观察:

  • 小模型(7B)有时仍使用 MHA(如 Llama-2 7B、Qwen-1.5 7B)
  • 大模型几乎都使用 GQA
  • KV 头数通常是 8,这是一个经验上的甜点

23.6.3 为什么 8 个 KV 头?

研究表明,KV 头数从 1 增加到 8 时,模型质量显著提升;但从 8 增加到更多时,边际收益递减。

同时,8 这个数字对 GPU 的张量并行(Tensor Parallelism)也很友好——可以均匀分配到 2、4、8 个 GPU 上。


23.7 从 MHA 转换到 GQA

23.7.1 Uptraining 技术

一个实用的问题:如果已经有一个训练好的 MHA 模型,如何转换成 GQA?

Google 的 GQA 论文提出了 Uptraining 方法:

  1. 权重平均:将需要合并的 K、V 头的权重取平均
  2. 继续训练:用原始数据的一小部分(如 5%)继续训练
  3. 恢复质量:模型快速适应新的结构
def convert_mha_to_gqa(mha_weights, n_heads, n_kv_heads):
    """将 MHA 权重转换为 GQA 权重"""
    group_size = n_heads // n_kv_heads

    # 原始 K 权重: [d_model, n_heads * head_dim]
    # 重塑为: [d_model, n_heads, head_dim]
    k_weights = mha_weights['W_k'].reshape(d_model, n_heads, head_dim)

    # 分组平均: [d_model, n_kv_heads, head_dim]
    k_grouped = k_weights.reshape(d_model, n_kv_heads, group_size, head_dim)
    k_gqa = k_grouped.mean(dim=2)  # 组内平均

    # V 权重同理
    # ...

    return gqa_weights

23.7.2 为什么 Uptraining 有效?

回到我们之前的发现:MHA 中不同头的 K、V 表示存在相似性。

这意味着取平均后:

  • 保留了共同的重要信息
  • 丢失的主要是头间的细微差异
  • 少量的继续训练可以快速弥补

23.8 实现细节与优化

23.8.1 Flash Attention 与 GQA

Flash Attention(我们在第 21 章讨论过)与 GQA 是完美搭档:

  • Flash Attention 优化计算效率
  • GQA 优化内存效率

现代框架(如 vLLM、TensorRT-LLM)都同时支持这两种优化。

23.8.2 张量并行中的 GQA

在多 GPU 推理中,GQA 有额外的好处:

MHA

GPU 0: Q  0-7,  K  0-7,  V  0-7
GPU 1: Q  8-15, K  8-15, V  8-15
...

GQA(n_heads=32, n_kv_heads=8):

GPU 0: Q  0-7,  K  0-1, V  0-1
GPU 1: Q  8-15, K  2-3, V  2-3
...

每个 GPU 需要存储的 KV Cache 减少了 4 倍!

23.8.3 Prefill vs Decode 的影响

回顾 LLM 推理的两个阶段:

Prefill(首次处理 prompt):

  • 计算密集型
  • GQA 减少的参数量 → 略微加速

Decode(逐个生成 token):

  • 内存带宽密集型
  • GQA 减少的 KV Cache → 显著加速

GQA 在 Decode 阶段的收益更大。


23.9 常见误区

23.9.1 误区一:GQA 就是 MQA 加了几个头

错误理解:GQA 只是 MQA 的简单扩展。

正确理解:GQA 是一种参数化的设计,n_kv_heads 是可调的超参数。MHA 和 MQA 是 GQA 的两个极端情况。

23.9.2 误区二:KV 头越少越好

错误理解:既然减少 KV 头可以省内存,那就用最少的。

正确理解:KV 头数与模型质量存在权衡。过少的 KV 头会导致显著的质量下降。8 是一个经验上的良好选择。

23.9.3 误区三:GQA 只影响推理

错误理解:GQA 只是推理优化,与训练无关。

正确理解:GQA 减少了 K、V 投影矩阵的参数量,会影响训练。但由于训练不使用 KV Cache,影响较小。

23.9.4 误区四:所有模型都应该用 GQA

错误理解:GQA 全面优于 MHA。

正确理解:对于小模型(如 7B 以下),MHA 的内存开销可能可以接受,而其更强的表达能力可能更重要。选择取决于具体的部署约束。


23.10 本章要点回顾

23.10.1 三种机制一览

MHA (Multi-Head Attention)
    每个头独立的 Q、K、V
    KV Cache: n_heads * 2 个张量/
    质量最高,内存最大

MQA (Multi-Query Attention)
    所有头共享 K、V,Q 独立
    KV Cache: 2 个张量/
    内存最小,质量有损

GQA (Grouped-Query Attention)
    Q 头分组,组内共享 K、V
    KV Cache: n_kv_heads * 2 个张量/
    平衡质量与效率

23.10.2 选择指南

场景推荐原因
研究/训练MHA最大表达能力
大规模部署GQA平衡质量与效率
端侧/极致效率MQA最小内存占用
不确定GQA (8 KV 头)稳妥的默认选择

23.10.3 核心认知

从 MHA 到 MQA 到 GQA 的演进,体现了深度学习工程化的一个核心主题:在模型质量和部署效率之间寻找最佳平衡。GQA 通过分组共享 K、V,用 25% 的内存换取接近 100% 的质量,成为了现代大模型的主流选择。


本章交付物

学完这一章,你应该能够:

  • 解释 MHA、MQA、GQA 的核心区别
  • 计算三种机制的 KV Cache 大小
  • 理解 GQA 为什么能在质量和效率间取得平衡
  • 知道主流模型(Llama-3、Mistral、Qwen)的 GQA 配置
  • 实现 GQA 的 repeat_kv 操作

延伸阅读

  1. Fast Transformer Decoding: One Write-Head is All You Need(MQA 原论文,2019)
  2. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints(GQA 原论文,2023)
  3. Llama 2: Open Foundation and Fine-Tuned Chat Models(Meta,2023)

下一章预告

我们已经了解了如何通过共享 K、V 来优化内存。但还有另一个角度:稀疏注意力(Sparse Attention)

不是让所有 token 都关注所有其他 token,而是有选择地关注——这不仅节省内存,还能处理超长序列。

下一章,我们来探索 Sparse Attention 的各种变体:Sliding Window、Dilated、Longformer 等。

引用本文 / Cite
Zhang, W. (2026). 第 23 章:MHA 到 MQA 到 GQA 演进. In Transformer 架构:从直觉到实现. https://waylandz.com/llm-transformer-book/第23章-MHA到MQA到GQA演进
@incollection{zhang2026transformer_第23章_MHA到MQA到GQA演进,
  author = {Zhang, Wayland},
  title = {第 23 章:MHA 到 MQA 到 GQA 演进},
  booktitle = {Transformer 架构:从直觉到实现},
  year = {2026},
  url = {https://waylandz.com/llm-transformer-book/第23章-MHA到MQA到GQA演进}
}