一文要約: Flash Attention は、GPU のメインメモリに N×N のスコアマトリクスを書き出すことを避けるために再実装された、完全に正確な Attention です。少し余分な演算をする代わりに、大幅なメモリ転送削減を実現します。
21.1 なぜ標準 Attention は遅くなるのか
21.1.1 不思議な現象
NVIDIA A100 GPU で Transformer を学習しているとします。A100 の理論スループットは 312 TFLOPS(FP16)です。この数字を見れば、猛烈に速いはずです。ところが、シーケンス長が伸びると学習が急激に遅くなり、OOM エラーも頻発するようになります。
さらに奇妙なことに、GPU のメモリに余裕があるときでも、GPU 使用率が問題なさそうなときでも、Attention がボトルネックになります。
答えは、多くの人が見落とすところに隠れています。メモリ帯域幅です。
21.1.2 Attention のメモリ問題
標準 Attention は次のように計算されます:
シーケンス長を 、トークンあたりの次元数を とすると:
- の形状:
- の形状:
- の形状:
のとき、スコアマトリクスには 個の要素が入ります。FP16 では約 32 MB — ヘッド1つ、サンプル1つ分です。
現実的な学習ランに合わせてスケールアップすると(32ヘッド、バッチサイズ8、順伝播+逆伝播):
Attention のスコアマトリクスだけで 8 GB です。そして、学習の毎ステップでこのデータを GPU のメインメモリ経由で動かさなければなりません。
21.1.3 GPU のメモリ階層
GPU のメモリは、単一のフラットなプールではありません。階層構造を持っています:
| レベル | 名称 | 容量 | 帯域幅 | 備考 |
|---|---|---|---|---|
| オンチップ | SRAM(L1/L2/共有メモリ) | 〜20 MB | 〜19 TB/s | 非常に速い、非常に小さい |
| デバイス | HBM(高帯域幅メモリ) | 〜40〜80 GB | 〜1.5〜3 TB/s | GPU メインメモリ |
| ホスト | CPU DRAM | 〜1 TB | 〜12.8 GB/s | 大容量だが大幅に遅い |
SRAM は HBM の 約20倍速です。
机の上(SRAM)と、部屋の向こうにある本棚(HBM)を想像してください。机の上にあるものはすぐに使えます。本棚の本を使うには、歩いて取りに行き、持ち帰らなければなりません。
標準 Attention はこのように動作します:
- を HBM から読み込む → を計算する → HBM に書き戻す
- を HBM から読み込む → Softmax を計算する → HBM に書き戻す
- Softmax の結果と を HBM から読み込む → 出力を計算する → HBM に書き戻す
HBM への往復ごとに時間が消費されます。これが本当のボトルネックです。
21.2 標準 Attention と Flash Attention:数字で比較する
21.2.1 標準 PyTorch の動作
# 標準 Attention
def standard_attention(Q, K, V):
# ステップ1: QK^T を計算し、結果を HBM に保存する
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# ステップ2: scores を HBM から読み込み、softmax を計算し、HBM に書き戻す
attention_weights = torch.softmax(scores, dim=-1)
# ステップ3: オプションの Dropout — さらに HBM の往復が発生する
attention_weights = dropout(attention_weights)
# ステップ4: weights と V を HBM から読み込み、出力を計算する
output = torch.matmul(attention_weights, V)
return output
GPT-2 サイズの Attention での比較:
- PyTorch 標準:〜15 ms。Matmul・Dropout・Softmax・Mask が個別カーネルとして分散実行
- Flash Attention:〜3 ms。すべてを1つの fused kernel に統合
HBM の往復を排除するだけで、5倍の高速化を実現しています。
21.2.2 メモリ計算量
メモリの比較はさらに明確です:
| 実装 | Attention マトリクスのメモリ |
|---|---|
| 標準 Attention | — 完全なスコアマトリクスを HBM に保持 |
| Flash Attention | — 入力と出力のみ、中間マトリクスなし |
のとき、Flash Attention は中間メモリを約 2048分の1 に削減します。 ではその比率がさらに倍になります。
21.3 タイリング:コアアイデア
21.3.1 直感的な理解
100×100 の巨大な掛け算の表を計算する必要があるとします。ナイーブなアプローチ:
- 表全体を計算し、大きな紙に書き出す。
- 各行を後処理する(Softmax)。
- 計算を続ける。
Flash Attention はこう考えます:
- 大きな表を 10×10 のタイルに切り分ける。
- スクラッチパッド(SRAM)上で、1つのタイルを丸ごと処理する。
- 表全体を書き出すことなく、最終結果を積み上げる。
各タイルに対して SRAM 内で行う処理の手順:
1. Q_block @ K_block^T
2. 因果マスクを適用する
3. オンライン Softmax(21.4節参照)
4. オプションの Dropout
5. V_block と掛け算する
6. 出力 O_i に積み上げる
HBM に戻す必要があるのは最終出力 だけです。巨大な中間 マトリクスは、メモリ上に一切存在しません。
21.3.2 タイルのサイズはどのくらいか
A100 では、ストリーミングマルチプロセッサあたりの SRAM は約 192 KB です。同時に4つのものを収める必要があります:
- の1ブロック:
- の1ブロック:
- の1ブロック:
- 出力の1ブロック:
ブロックサイズの式:
ここで は SRAM サイズ、 はモデルの次元数です。、(モデルの幅)とすると:
実際にはメモリアライメントの都合で 64 に切り下げるため、典型的な A100 実装では になります。
21.3.3 コアループ
アルゴリズム: FlashAttention(簡略版)
入力: Q, K, V を HBM に
出力: O を HBM に
for j = 1 to T_c: # K, V ブロックのアウターループ
K_j, V_j を SRAM に読み込む
for i = 1 to T_r: # Q ブロックのインナーループ
Q_i, O_i, l_i, m_i を HBM から SRAM に読み込む
S_ij = Q_i @ K_j^T # スコアブロック
m_i を更新(実行中の最大値)
l_i を更新(実行中の分母)
O_i += rescaled_P_ij @ V_j # 出力を積み上げる
O_i, l_i, m_i を HBM に書き戻す
「rescaled」の部分を担当するのが、オンライン Softmax です。
21.4 オンライン Softmax:全体を見ずに Softmax を計算する
21.4.1 問題
標準 Softmax:
分母はすべての要素を合計します。ところがタイリング中は、1度に1ブロックしか見えません。どうやって正確に Softmax を計算すればよいでしょうか?
21.4.2 オンライン更新ルール
3つの実行中の値を管理します:
- :これまでに見た最大値
- :分子項のベクトル(スケール補正済みの指数)
- :分子項の合計(分母の累積値)
ブロック を処理した後に新しいブロック が来たとき:
1. 実行中の最大値を更新する:
2. 過去の分子をスケール補正する:
3. 分母を更新する:
4. 最終 Softmax:
21.4.3 なぜ最大値を追跡するのか
の補正係数は、数値安定性のためです。
大きな に対して を計算するとオーバーフローします。標準的な対処法は、指数計算の前に最大値を引くことです:
タイリング計算では、各ブロックがそれぞれのローカル最大値を持ちます。新しいブロックが来てグローバル最大値が更新されたとき、 を使ってそれまでの累積値を遡って補正します。
21.4.4 計算例
完全な行:。2つのブロックに分割します。
ブロック1 — :
ブロック2 — :
- 新しいグローバル最大値:(変わらず)
- 補正係数:ブロック1は ;ブロック2は
最初の要素の Softmax:
直接計算すると 50.28% になります。わずかな差は例題の丸めによるもので、実際の計算は厳密に正確です。
21.5 FlashAttention の完全なアルゴリズム
アルゴリズム: FLASHATTENTION
入力: Q, K, V ∈ R^{N×d} を HBM に;オンチップ SRAM のサイズ M
1. B_c = ceil(M / 4d)、B_r = min(ceil(M / 4d), d) を設定する
2. O = 0、l = 0、m = -∞ を HBM に初期化する
3. Q を T_r = ceil(N / B_r) 個のブロックに分割する
4. K, V を T_c = ceil(N / B_c) 個のブロックに分割する
5. for j = 1 to T_c:
6. K_j, V_j を HBM から SRAM に読み込む
7. for i = 1 to T_r:
8. Q_i, O_i, l_i, m_i を HBM から SRAM に読み込む
9. S_ij = Q_i @ K_j^T
10. m̃_ij = rowmax(S_ij)
11. P̃_ij = exp(S_ij − m̃_ij)、l̃_ij = rowsum(P̃_ij)
12. m_i_new = max(m_i, m̃_ij)
13. l_i_new = exp(m_i − m_i_new) × l_i + exp(m̃_ij − m_i_new) × l̃_ij
14. O_i = diag(l_i_new)^{-1} × (diag(l_i) × exp(m_i − m_i_new) × O_i
+ exp(m̃_ij − m_i_new) × P̃_ij × V_j)
15. O_i, l_i_new, m_i_new を HBM に書き戻す
16. O を返す
21.5.1 IO 計算量
標準 Attention:
- を書き込む:
- 読み込み、Softmax、書き込み:
- Softmax + V を読み込み、出力を書き込む:
- HBM トラフィック合計:
Flash Attention:
- 各 K/V ブロックをアウターループで1回読み込む: 回の 読み込み
- 各 Q/O ブロックをインナーループで読み書きする: 回の
- HBM トラフィック合計:
のとき、Flash Attention の IO 計算量は に近づきます。標準パスに対して 倍の改善です。
21.6 Flash Attention 1 と Flash Attention 2
21.6.1 FA1 が達成したこと
Flash Attention 1(2022年)はこのアイデアを証明し、実際の高速化を実現しました。インナーループの並列化は、出力アキュムレータを共有するワーカー間の同期によって制約されていました。
21.6.2 FA2 が追加したこと
Flash Attention 2(2023年)は3つの重要な変更を加えました:
- より良いワーク分割 — ストリーミングマルチプロセッサ間の同期を減らし、ハードウェアをより均一に活用する
- MQA と GQA のネイティブサポート — 第23章で扱うヘッド共有パターンを直接処理する
- 非行列積演算の削減 — レジスタのスピルが減り、パイプラインがクリーンになる
A100 80GB SXM4 でのパフォーマンス:
| 設定 | PyTorch | FA1 | FA2 |
|---|---|---|---|
| シーケンス長 2k、head_dim 64 | 〜50 TFLOPS | 〜120 TFLOPS | 〜175 TFLOPS |
| シーケンス長 4k、head_dim 64 | 〜45 TFLOPS | 〜110 TFLOPS | 〜170 TFLOPS |
| シーケンス長 8k、head_dim 128 | 〜40 TFLOPS | 〜100 TFLOPS | 〜165 TFLOPS |
FA2 は、メモリバウンドな演算において A100 のピークスループットの 50〜70% に達します。これは非常に優秀な数字です。
21.7 実際の使い方
21.7.1 インストール
pip install flash-attn --no-build-isolation
21.7.2 直接 API を使う
import torch
from flash_attn import flash_attn_func
batch_size, seq_len, num_heads, head_dim = 2, 4096, 32, 128
# 形状: [batch, seq_len, num_heads, head_dim]
q = torch.randn(batch_size, seq_len, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)
output = flash_attn_func(q, k, v, causal=True)
21.7.3 PyTorch 2.0+ の組み込み機能
PyTorch 2.0 は scaled_dot_product_attention を追加しました。入力が条件を満たす場合、自動的に Flash Attention にディスパッチされます:
import torch.nn.functional as F
output = F.scaled_dot_product_attention(
query, key, value,
attn_mask=None,
dropout_p=0.0,
is_causal=True,
)
# PyTorch はハードウェアと入力の形状に応じて、
# Flash Attention、Memory Efficient Attention、
# または標準パスを自動選択します。
21.7.4 Hugging Face Transformers
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
)
21.7.5 知っておくべき制限事項
ハードウェア:Flash Attention には最新の NVIDIA GPU が必要です。Ampere(A100)以降が最良の結果を出します。古いカードや NVIDIA 以外のハードウェアは標準パスにフォールバックします。
逆伝播:Flash Attention は完全なスコアマトリクスを保存しないため、逆伝播でそれを再計算する必要があります。余分な演算はIO の節約に比べると安価で、エンドツーエンドの学習では依然として 2〜4倍のアドバンテージがあります。
非標準マスク:カスタムの Attention マスク(スパース、スライディングウィンドウ、任意のパターン)には特別な処理が必要な場合があります。FA2 はすでに因果マスク・パディングマスク・MQA/GQA をすぐに使える形でサポートしています。
21.8 章のまとめ
| 概念 | ポイント |
|---|---|
| ボトルネック | 演算スループットではなく、HBM 帯域幅 |
| SRAM vs HBM | SRAM 〜19 TB/s;HBM 〜1.5 TB/s;SRAM は約20倍速 |
| タイリング | Q/K/V の小さなブロックを SRAM で処理し、N×N マトリクスを一切書き出さない |
| オンライン Softmax | 実行中の最大値・分子・分母を追跡し、グローバル最大値が更新されたら過去のブロックを補正する |
| メモリ計算量 | 標準:;Flash: |
| FA1 → FA2 | より良い並列化、MQA/GQA のネイティブサポート、FA1 比 1.5〜2倍 |
| エンドツーエンドの高速化 | 学習で 2〜4倍;Attention カーネル単体で 5倍 |
章末チェックリスト
この章を終えた後、以下のことができるようになっているはずです:
- TFLOPS ではなく HBM 帯域幅が Attention のボトルネックである理由を説明できる。
- タイリングが何をするものか、なぜ N×N マトリクスのマテリアライズを回避できるかを説明できる。
- オンライン Softmax の更新ルールを順を追って説明できる。
- 標準 Attention と Flash Attention のメモリ計算量を言える。
- Flash Attention が近似ではなく厳密であることを説明できる。
- FA1 と FA2 を比較できる。
参考文献
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022) — arXiv 2205.14135
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (Dao, 2023) — arXiv 2307.08691
- Memory Efficient Attention (xFormers)
- PagedAttention (vLLM) — KV Cache 管理のための補完的アプローチ
次章へ
Flash Attention は個々の Attention 計算をより安価にします。しかし、自己回帰的な生成では、別の問題が残っています。モデルは毎ステップ、古いトークンの K と V を再計算し続けています。
第22章では KV Cache でその問題を解決します。これは Flash Attention の自然なパートナーであり、高速な推論のための2本柱のうちの1本です。