一文要約: MHA は各ヘッドが独自の K と V を持つ(最も表現力が高く、メモリも最大)。MQA はすべてのヘッドが一つの K/V を共有する(最小メモリだが品質がやや落ちる)。GQA はグループ共有で中間をとる。現代の LLM の大半はこの GQA を採用しています。


23.1 KV キャッシュが生み出した問題

23.1.1 メモリの計算

第22章で確認したとおり、プロダクション推論において KV キャッシュは欠かせません。しかし Multi-Head Attention では各ヘッドが独自の K と V をキャッシュするため、積み重なると大変なことになります。

典型的な 7B パラメータモデルの場合 — 32 層、32 ヘッド、ヘッド次元 128、FP16:

リクエストごとの KV キャッシュ =
    32  × 32 ヘッド × 2 (K  V) × seq_len × 128 × 2 バイト

seq_len = 1024 のとき:

32 × 32 × 2 × 1024 × 128 × 2 = 536 MB

これは 1 ユーザー、1 会話、1024 トークンで 536 MB です。100 人が同時に 4096 トークンを使うとスケールすると:

536 MB × 4 (4096/1024) × 100 ユーザー  200 GB

KV キャッシュだけで 200 GB。A100 が丸二枚つぶれる計算です。業界が「ヘッドごとに独立した K/V は本当に必要なのか」と問い始めた理由がわかります。

23.1.2 根本的なトレードオフ

MHA はもともと学習向けに設計されています。各ヘッドが異なるパターンを捉えるために独立した射影を学ぶ、それが強みです。しかし推論では、その独立した射影がコストになります。レイヤーごと、ヘッドごと、トークンごとに K/V を保存しなければなりません。

学習は表現力を求める。サービングは効率を求める。MQA と GQA はそのトレードオフから生まれたアーキテクチャです。

23.1.3 三つのメカニズム

メカニズム正式名称核心的なアイデア
MHAMulti-Head Attention各ヘッドが独立した K、V を持つ
MQAMulti-Query Attentionすべてのヘッドが一つの K、V を共有する
GQAGrouped-Query Attentionグループ単位でヘッドが一つの K/V を共有する

23.2 MHA: ベースライン

23.2.1 構造

MHA: 各ヘッドが独自の Q、K、V 射影を持つ

n_heads ヘッドの標準 MHA では、各ヘッドがそれぞれ:

  • 独自の WQ(i)W_Q^{(i)} 射影
  • 独自の WK(i)W_K^{(i)} 射影
  • 独自の WV(i)W_V^{(i)} 射影

を持ちます。KV キャッシュにはレイヤーごとに n_heads 個の K テンソルと n_heads 個の V テンソルが格納されます。

23.2.2 マルチヘッドが役立つ理由

異なるヘッドは本当に異なるものを学習します。「エージェントがレビュアーをタグ付けした、なぜなら PR が緊急だったから」という文を考えてみましょう:

  • ヘッド 1 は構文上の主語-動詞を追う: エージェント → タグ付けした
  • ヘッド 2 は代名詞の解消を追う: 「その PR」← どの PR?
  • ヘッド 3 は因果推論を追う: タグ付けした → なぜなら → 緊急
  • ヘッド 4 は直近性を追い、直近のトークンに強く注目する

独立した K/V 射影によって、各ヘッドがトークン履歴に対して独自の「視点」を構築できます。これが MHA の強みです。

23.2.3 MHA のコード形状

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads  = n_heads
        self.head_dim = d_model // n_heads

        # すべてのヘッドを含む d_model 射影
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)  # n_heads 個の独立した K 射影
        self.W_v = nn.Linear(d_model, d_model)  # n_heads 個の独立した V 射影
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.head_dim)
        k = self.W_k(x).view(B, T, self.n_heads, self.head_dim)
        v = self.W_v(x).view(B, T, self.n_heads, self.head_dim)
        # KV キャッシュは n_heads 組の K  V を格納する

23.2.4 プロダクションでの問題

32 ヘッドのモデルでは、各レイヤーの KV キャッシュに 64 テンソル(32 K + 32 V)が必要です。長いコンテキストや高い並行数では、これがリクエストを捌ける量の制約になります。

ツール呼び出しを組み合わせたエージェントシステムで 16k コンテキストを使う場合を具体的に計算すると:

セッションごとの KV キャッシュ (Llama-7B、16k ctx、FP16):
    32  × 32 ヘッド × 2 (K+V) × 16384 × 128 × 2 バイト  8 GB

1 アクティブセッションで 8 GB です。モデルの重みを読み込んだ後に 40 GB の GPU が残っていれば、長いコンテキストセッションをせいぜい 4 つしか並行で捌けません。チームで使うことを考えると、KV キャッシュのメモリを削減したいという圧力がよくわかります。


23.3 MQA: すべてを集約する

23.3.1 核心的なアイデア

Multi-Query Attention(Shazeer, 2019)はシンプルですが大胆な選択をします: すべてのクエリヘッドが一つの K と一つの V を共有する

MQA: 多くの Q ヘッドが一つの K と一つの V を共有する
  • Q は引き続き n_heads 個の独立した射影を持つ
  • K は 1 個の射影
  • V は 1 個の射影

KV キャッシュに格納されるのは、クエリヘッドの数にかかわらず レイヤーごとに 2 テンソルだけです。

23.3.2 メモリの節約

同じ 7B モデルで 1024 トークンの場合:

MHA KV キャッシュ = 32  × 32 ヘッド × 2 × 1024 × 128 × 2 = 536 MB
MQA KV キャッシュ = 32  ×  1 ヘッド × 2 × 1024 × 128 × 2 = 16.75 MB

97% 削減。MHA で 5 ユーザーしか捌けなかった同じ GPU が、MQA では約 160 ユーザーに対応できます。

23.3.3 MQA のコード形状

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads  = n_heads
        self.head_dim = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model)        # n_heads 分フル
        self.W_k = nn.Linear(d_model, self.head_dim)  # 1 ヘッドのみ!
        self.W_v = nn.Linear(d_model, self.head_dim)  # 1 ヘッドのみ!
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.head_dim)
        k = self.W_k(x).view(B, T, 1, self.head_dim)
        v = self.W_v(x).view(B, T, 1, self.head_dim)
        # k, v  [B,T,1,head_dim] から [B,T,n_heads,head_dim] にブロードキャストされる

23.3.4 コスト

すべてのクエリヘッドに同じ K/V 参照を強いることで、各ヘッドがトークン履歴の独立した視点を構築する力が制限されます。MQA は多くのタスクでうまく機能しますが、多様な長距離パターン捕捉が必要なタスクでは品質の低下が見られます。Google の PaLM は MQA を採用しましたが、フロンティアスケールでの品質低下はコミュニティに受け入れがたいものでした。


23.4 GQA: 現実的な中間点

23.4.1 核心的なアイデア

Grouped-Query Attention(Ainslie et al., 2023)はハイパーパラメータを一つ導入します: n_kv_heads、つまり K/V グループの数です。

クエリヘッドは n_kv_heads 個のグループに分けられます。同じグループ内のすべてのクエリヘッドが一つの K 射影と一つの V 射影を共有します。

GQA: Q ヘッドをグループに分け、各グループが一つの K と V を共有する

形式的には:

  • n_heads — Q ヘッドの数
  • n_kv_heads — KV グループの数
  • n_rep = n_heads / n_kv_heads — グループあたりの Q ヘッド数

特殊なケース:

  • n_kv_heads = n_heads → MHA(各ヘッドが独立)
  • n_kv_heads = 1 → MQA(すべてのヘッドが共有)
  • 1 < n_kv_heads < n_heads → GQA

23.4.2 GQA のコード形状

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        self.n_heads    = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep      = n_heads // n_kv_heads
        self.head_dim   = d_model // n_heads

        self.W_q = nn.Linear(d_model, n_heads    * self.head_dim)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        q = self.W_q(x).view(B, T, self.n_heads,    self.head_dim)
        k = self.W_k(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.W_v(x).view(B, T, self.n_kv_heads, self.head_dim)

        # K  V  Q のヘッド数に合わせて展開する
        k = self.repeat_kv(k)  # [B, T, n_heads, head_dim]
        v = self.repeat_kv(v)
        # ここから先は MHA と同じ Attention の計算

    def repeat_kv(self, x):
        """各 KV グループを n_rep 回繰り返して Q ヘッド数に合わせる。"""
        B, T, n_kv, head_dim = x.shape
        if self.n_rep == 1:
            return x
        # [B, T, n_kv, head_dim] -> [B, T, n_kv, n_rep, head_dim]
        x = x.unsqueeze(3).expand(B, T, n_kv, self.n_rep, head_dim)
        # -> [B, T, n_heads, head_dim]
        return x.reshape(B, T, self.n_heads, head_dim)

23.4.2b 学習と推論の違い

一つ理解しておくべき微妙な点があります。学習時は KV キャッシュを使いません(シーケンス全体をカジュアルマスキングで並列処理します)。そのため学習における GQA のメリットは、K と V の射影行列が小さくなるというパラメータ数の削減だけです。小さいですが、ゼロではありません。

推論時はメリットがずっと大きくなります。デコードはメモリ帯域幅に律速されます(第22章 22.6.2 節)。生成される各トークンは KV キャッシュ全体を HBM から読み込みます。KV キャッシュが小さければ、より多くが SRAM に収まり、トークンごとの HBM 読み込みが減り、スループットが上がります。GQA の 4× や 8× のメモリ削減は、ほぼそのままデコード速度の向上につながります。

23.4.3 repeat_kv の幾何学的イメージ

n_heads = 8n_kv_heads = 2 のとき:

元の K/V の形状: [B, T, 2, head_dim]

  KV グループ 0          KV グループ 1

repeat_kv : [B, T, 8, head_dim]

  Q0  Q1  Q2  Q3  Q4  Q5  Q6  Q7
                        
  KV0 KV0 KV0 KV0 KV1 KV1 KV1 KV1

Q ヘッド 0〜3 は KV グループ 0 を共有し、Q ヘッド 4〜7 は KV グループ 1 を共有します。計算上これはテンソルの繰り返しであり、独立した射影ではありません。パラメータは追加されません。


23.5 三者比較

23.5.1 メモリの数値

7B モデル、1024 トークンシーケンス、FP16 の場合:

KV キャッシュサイズの比較: MHA vs GQA vs MQA

MHA(32 KV ヘッド):

32 × 32 × 2 × 1024 × 128 × 2 バイト = 536 MB

GQA(8 KV ヘッド):

32 ×  8 × 2 × 1024 × 128 × 2 バイト = 134 MB

MQA(1 KV ヘッド):

32 ×  1 × 2 × 1024 × 128 × 2 バイト = 16.75 MB
メカニズムKV ヘッドKV キャッシュMHA 比
MHA32536 MB100%
GQA8134 MB25%
MQA116.75 MB3.1%

GQA は MHA のメモリコスト 25% で、MHA に近い品質を保ちます。MQA は 3.1% まで到達しますが、品質の代償がより大きくなります。

23.5.2 品質と効率のトレードオフ

GQA ベンチマーク: MHA、GQA、MQA 間の推論時間とモデル品質の比較

GQA 論文のベンチマークから:

  • GQA-G8(8 グループ)は品質が MHA に近い
  • GQA-G8 の推論時間は MQA に近い
  • 8 グループを超えて KV グループを増やしても品質改善はすぐに頭打ちになる

重要な実験的知見があります。学習済みの MHA モデルでは、異なるヘッドの K と V の表現が驚くほど似ていることが多いのです。多くのヘッドがほぼ冗長な射影を学びます。だからグループ内で K/V を共有してもそれほど損をしない — 失う多様性はもともとそれほど有用なシグナルを提供していなかったからです。

この知見にはアーキテクチャ上の示唆があります。既存の MHA チェックポイントを変換するのではなく最初から設計するなら、GQA で直接学習すれば、モデルは最初から KV 容量を効率的に使うことを学べます。MHA における冗長性は、学習中に K/V 表現をヘッド間で差別化するインセンティブがないことの副産物でもあります。

23.5.3 サービング並行数への影響

上記のメモリ数値は、同時に何人のユーザーに対応できるかを直接決定します。A100 80GB GPU で 7B モデルを FP16 でロードすると(14 GB)、KV キャッシュ用に約 66 GB 残ります:

メカニズムセッションあたり KV (4k ctx)最大並行セッション数
MHA536 MB × 4 = 2.1 GB約 31
GQA(8 ヘッド)134 MB × 4 = 536 MB約 123
MQA16.75 MB × 4 = 67 MB約 984

GQA は同じハードウェア予算で MHA に対して並行ユーザー数をほぼ 4 倍にします。これがビジネスケースを一つの表で示したものです。

23.5.4 完全なトレードオフ表

メカニズム品質推論速度KV メモリ使いどころ
MHA最高最遅最大研究、小規模モデル、学習オンリーの設定
MQAやや劣る最速最小エッジ/モバイル、極端なスループット要件
GQAMHA に近いMQA に近い中程度プロダクションのほぼすべて

23.6 現代のモデルが採用するもの

23.6.1 業界は GQA に収束している

GQA を採用するプロダクションモデル: Llama-3、Mistral、Qwen
モデルパラメータQ ヘッドKV ヘッドグループサイズ
Llama-2 7B7B32321 (MHA)
Llama-2 70B70B6488
Llama-3 8B8B3284
Llama-3 70B70B6488
Mistral 7B7B3284
Qwen-1.5 7B7B32321 (MHA)
Qwen-1.5 32B32B4085
Qwen-2 7B7B2847

いくつか観察できます。小さいモデルが MHA に留まることがあるのは、絶対的なメモリコストが管理可能で、表現力をフル活用したいからです。大きいモデルはほぼ例外なく GQA を選びます。KV ヘッド 8 というのが一般的なスイートスポットです。

23.6.2 なぜ 8 KV ヘッドなのか

研究によると、品質は 1 から 8 KV ヘッドにかけて急速に改善し、そこから頭打ちになります。一方、8 は一般的なテンソル並列構成(2、4、または 8 GPU)に均等に割り切れるため、KV ヘッドをデバイス間できれいに分散できます。経験的にも優れており、運用上も便利という組み合わせです。

23.6.3 マルチ GPU でのメリット

テンソル並列サービング(例: 4 GPU)の場合:

MHA(32 ヘッド):

GPU 0: Q ヘッド 0–7,  K ヘッド 0–7,  V ヘッド 0–7
GPU 1: Q ヘッド 8–15, K ヘッド 8–15, V ヘッド 8–15
...

GQA(32 Q ヘッド、8 KV ヘッド):

GPU 0: Q ヘッド 0–7,  K ヘッド 0–1, V ヘッド 0–1
GPU 1: Q ヘッド 8–15, K ヘッド 2–3, V ヘッド 2–3
...

各 GPU の KV キャッシュが 4× 小さくなります。多数の並行リクエストを実行するときに重要です。


23.7 MHA チェックポイントを GQA に変換する

既に学習済みの MHA モデルがあれば、Google の GQA 論文が提案した uptraining を使えます:

  1. 重みを平均化する — マージする K/V ヘッドの各グループについて、それらの射影行列の平均を取る
  2. 継続学習する — 元の学習データの約 5% で短いファインチューニングを実行する
  3. 品質を回復させる — 平均化された重みが合理的な初期化になっているため、モデルは素早く適応する
def convert_mha_to_gqa(k_weights, n_heads, n_kv_heads, head_dim, d_model):
    """K(または V)射影行列のグループを平均化する。"""
    group_size = n_heads // n_kv_heads
    # k_weights の形状: [d_model, n_heads * head_dim]
    k = k_weights.reshape(d_model, n_heads, head_dim)
    # グループ化して平均: [d_model, n_kv_heads, group_size, head_dim]
    k_grouped = k.reshape(d_model, n_kv_heads, group_size, head_dim)
    k_gqa = k_grouped.mean(dim=2)  # [d_model, n_kv_heads, head_dim]
    return k_gqa.reshape(d_model, n_kv_heads * head_dim)

これが機能するのは、先述の経験的冗長性のためです。平均化された結果はすでに、各グループの共有 K/V がどうあるべきかの合理的な近似になっています。

GQA 論文では、元の学習データのわずか 5% で品質ギャップのほとんどを回復できると報告されています。これにより uptraining が現実的になります。高品質な MHA モデルを一度学習すれば、効率的なサービング用の GQA モデルに安価に変換できます。


23.8 Flash Attention と GQA の組み合わせ

Flash Attention 2 はカーネルに直接ネイティブ GQA サポートを追加しました。これは効率の観点で重要です。

GQA を認識しない Flash Attention 実装では、タイルループの前に repeat_kv で K と V を展開する必要があり、HBM により大きなテンソルが作られます。ネイティブ GQA 対応では、カーネルが各 Q ブロックを KV グループインデックス(group_idx = q_head_idx // n_rep)にマッピングし、展開されたテンソルを実体化せずに正しい K/V タイルをロードします。

その効果: Flash Attention の IO 効率と GQA の小さな KV フットプリントの両方を、繰り返し操作による余分なメモリオーバーヘッドなしに得られます。F.scaled_dot_product_attention を呼び出すか、vLLM や TensorRT-LLM のような GQA 対応実装を使うと、この最適化は通常自動的に適用されます。


23.10 よくある誤解

「GQA はヘッド数を増やした MQA に過ぎない」 — 厳密には違います。GQA はパラメータ化されたファミリーです。MHA と MQA は両極端であり、GQA はその間のスペクトル全体です。設計上の重要な選択は n_kv_heads です。

「KV ヘッドは少ないほど常によい」 — 品質と効率のトレードオフは現実に存在します。32 から 8 KV ヘッドに減らすと、品質損失を最小限にメモリを 4× 削減できます。8 から 1 に減らすとさらに 8× 削減できますが、より目立った品質劣化が伴います。最適な設定はサービング制約と品質要件に依存します。

「GQA は推論にしか影響しない」 — GQA はパラメータ数もわずかに削減します(小さな K と V 射影行列)。これにより学習が速くなり、モデルファイルサイズも小さくなる可能性があります。その影響は小さいですが、ゼロではありません。32 ヘッドから 8 KV ヘッドに変更した 7B モデルの場合: K と V の射影行列は [d_model, d_model] から [d_model, d_model/4] に縮小し、総パラメータの約 6% が節約されます。

「すべてのモデルが GQA を使うべき」 — メモリが豊富な環境にデプロイされる非常に小さいモデル(7B 未満)では、MHA の追加表現力がオーバーヘッドに見合う場合があります。特定の n_kv_heads にコミットする前には、必ず品質を測定してください。

「Flash Attention と GQA は競合する」 — 補完関係にあります。Flash Attention 2 はネイティブ GQA サポートを追加しました。タイル計算中に repeat_kv を内部で処理するため、IO 効率の高いカーネルと小さな KV フットプリントの両方を同時に得られます。


23.11 章のまとめ

MHA (Multi-Head Attention)
  n_kv_heads = n_heads
  各ヘッドが独立した Q、K、V を持つ
  KV キャッシュ: レイヤーごとに 2 × n_heads テンソル
  最高品質、最大メモリ

MQA (Multi-Query Attention)
  n_kv_heads = 1
  すべてのヘッドが一つの K、V を共有
  KV キャッシュ: レイヤーごとに 2 テンソル
  最小メモリ、スケールでの品質リスクあり

GQA (Grouped-Query Attention)
  1 < n_kv_heads < n_heads
  グループ単位でヘッドが K/V を共有
  KV キャッシュ: レイヤーごとに 2 × n_kv_heads テンソル
  MHA に近い品質、MQA に近い効率

選択ガイド

状況推奨理由
研究・学習重視MHA最大の表現力
大規模プロダクションサービングGQA(8 KV ヘッド)最良の品質-効率バランス
エッジ/モバイル/極限の効率MQA最小メモリフットプリント
不確かな場合GQA(8 KV ヘッド)安全で実証済みのデフォルト

23.11.1 モデルの設定ファイルを読む

最新のモデルの Hugging Face config.json には、両方のフィールドが記載されています。Llama-3 8B の場合:

{
  "num_attention_heads": 32,
  "num_key_value_heads": 8
}

num_attention_headsn_heads(Q ヘッド数)で、num_key_value_headsn_kv_heads です。グループサイズは 32 / 8 = 4 — 各 KV ペアが 4 つのクエリヘッドで共有されます。num_key_value_headsnum_attention_heads と等しければ MHA、1 ならば MQA です。


チャプターチェックリスト

この章を終えた後、以下ができるようになっているはずです:

  • 長いコンテキストや高い並行数で MHA の KV キャッシュがボトルネックになる理由を説明できる。
  • MHA、MQA、GQA をそれぞれ一文で説明できる。
  • モデルの次元数が与えられたとき、各メカニズムの KV キャッシュサイズを計算できる。
  • 大きなメモリ節約にもかかわらず GQA の品質損失が小さい理由を説明できる。
  • モデルの設定ファイルを読んで n_headsn_kv_heads を特定できる。
  • repeat_kv を実装できる。
  • 各メカニズムのもとで、ある GPU が何人の並行ユーザーに対応できるかを推定できる。

参考文献

  1. Fast Transformer Decoding: One Write-Head is All You Need (Shazeer, 2019) — MQA 原著論文
  2. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (Ainslie et al., 2023)
  3. Llama 2: Open Foundation and Fine-Tuned Chat Models (Meta, 2023) — モデルスケールを超えた MHA→GQA の移行を示す

次の章へ

GQA によって保存する K/V データを減らすことができました。しかしキャッシュが小さくなっても、各トークンはウィンドウ内のすべての他のトークンに Attention を当てます — 完全 Attention の二次コストは残ったままです。

第24章では、その要件を完全に捨てたらどうなるかを探ります。Sparse Attention はシーケンスの選ばれた一部にしか Attention を当てないようにすることで、計算量を O(N) に近づけます。そして Infini Attention はさらに踏み込み、固定サイズの圧縮メモリを使って際限なく成長するコンテキストを扱います。次章でお会いしましょう。

このページを引用する
Zhang, Wayland (2026). 第23章: MHA から MQA、そして GQA へ. In Transformer アーキテクチャ:直感から実装まで. https://waylandz.com/llm-transformer-book-ja/chapter-23-mha-mqa-gqa
@incollection{zhang2026transformer_ja_chapter_23_mha_mqa_gqa,
  author = {Zhang, Wayland},
  title = {第23章: MHA から MQA、そして GQA へ},
  booktitle = {Transformer アーキテクチャ:直感から実装まで},
  year = {2026},
  url = {https://waylandz.com/llm-transformer-book-ja/chapter-23-mha-mqa-gqa}
}