One-sentence summary: Flash Attention is exact Attention reimplemented to avoid writing the full N×N score matrix to GPU main memory — trading a little extra arithmetic for a lot less memory traffic.
21.1 Why Standard Attention Gets Slow
21.1.1 A Confusing Observation
You are training a Transformer on an NVIDIA A100 GPU, which has a theoretical throughput of 312 TFLOPS (FP16). By that number, it should fly. But as sequence length grows, training slows down sharply and OOM errors become common.
Stranger still: even when GPU memory has headroom, even when GPU utilization looks fine, Attention becomes the bottleneck.
The answer is hiding in a place most people overlook: memory bandwidth.
21.1.2 The Attention Memory Problem
Standard Attention computes:

Attention(Q, K, V) = Softmax(QKᵀ / √d) · V

For sequence length N and per-token dimension d:
- Q shape: N × d
- K shape: N × d
- Scores S = QKᵀ shape: N × N
When N = 4096, that score matrix contains 4096² ≈ 16.8 million elements. In FP16 that is about 32 MB, for a single head, single sample.
Scale it up to a realistic training run (32 heads, batch size 8, forward + backward):

32 MB × 32 heads × 8 samples = 8 GB

That is 8 GB just to store the Attention score matrices, and the backward pass reads them all again. Every training step has to move all of it through GPU main memory.
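The back-of-envelope arithmetic, for readers who want to check it (N, head count, and batch size are the chapter's example values):

```python
# Memory footprint of the attention score matrix (FP16 = 2 bytes per element).
N = 4096                 # sequence length
bytes_fp16 = 2

one_matrix = N * N * bytes_fp16       # one head, one sample
print(one_matrix / 2**20, "MiB")      # -> 32.0 MiB

heads, batch = 32, 8
total = one_matrix * heads * batch    # all score matrices for one step
print(total / 2**30, "GiB")           # -> 8.0 GiB
```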
21.1.3 GPU Memory Hierarchy
GPU memory is not a single flat pool. It has a hierarchy:
| Level | Name | Capacity | Bandwidth | Notes |
|---|---|---|---|---|
| On-chip | SRAM (L1/L2/shared) | ~20 MB | ~19 TB/s | Very fast, very small |
| Device | HBM (high bandwidth memory) | ~40–80 GB | ~1.5–3 TB/s | GPU main memory |
| Host | CPU DRAM | ~1 TB | ~12.8 GB/s | Large, much slower |
By those numbers, SRAM is roughly an order of magnitude faster than HBM (19 / 1.5 ≈ 13×).
Think of your desk (SRAM) versus a bookshelf across the room (HBM). You can work on whatever is on your desk instantly. To reach the shelf, you have to walk over, grab the book, bring it back.
Standard Attention works like this:
- Read Q and K from HBM → compute S = QKᵀ → write S back to HBM
- Read S from HBM → do Softmax → write the weights P back to HBM
- Read P and V from HBM → compute output O = PV → write O back to HBM
Every round trip through HBM burns time. That is the real bottleneck.
21.2 Standard vs Flash: The Numbers
21.2.1 What Standard PyTorch Does
```python
import math
import torch
import torch.nn.functional as F

# Standard Attention
def standard_attention(Q, K, V, dropout_p=0.0):
    d_k = Q.size(-1)
    # Step 1: compute QK^T, result lives in HBM
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # Step 2: read scores from HBM, compute softmax, write back to HBM
    attention_weights = torch.softmax(scores, dim=-1)
    # Step 3: optional dropout, another HBM round trip
    attention_weights = F.dropout(attention_weights, p=dropout_p)
    # Step 4: read weights and V from HBM, compute output
    output = torch.matmul(attention_weights, V)
    return output
```
On GPT-2-sized Attention:
- PyTorch standard: ~15 ms, split across Matmul, Dropout, Softmax, Mask as separate kernels
- Flash Attention: ~3 ms, fusing everything into a single kernel
That is a 5× speedup just from eliminating HBM round trips.
21.2.2 Memory Complexity
The memory picture is even cleaner:
| Implementation | Memory for Attention matrix |
|---|---|
| Standard Attention | O(N²): full score matrix in HBM |
| Flash Attention | O(N): only inputs and outputs, no intermediate matrix |
At N = 4096, Flash Attention cuts the intermediate memory by a factor of roughly 2048. At N = 8192 that factor doubles again.
21.3 Tiling: The Core Idea
21.3.1 The Intuition
Imagine you need to compute a very large multiplication table — say 100×100 entries. The naive approach:
- Compute the whole table, write it on a big sheet.
- Post-process each row (Softmax).
- Continue the calculation.
Flash Attention does this instead:
- Cut the big table into 10×10 tiles.
- Process one tile at a time, entirely on the scratchpad (SRAM).
- Accumulate the final result without ever writing the full table.
The sequence of operations inside SRAM for each tile:
1. Q_block @ K_block^T
2. apply causal mask
3. online Softmax (see Section 21.4)
4. optional dropout
5. multiply by V_block
6. accumulate into output O_i
Only the final output needs to go back to HBM. The giant intermediate matrix never exists in memory at all.
21.3.2 How Big Are the Tiles?
On an A100, SRAM per streaming multiprocessor is about 192 KB. We need to fit four things simultaneously:
- one block of Q: B × d
- one block of K: B × d
- one block of V: B × d
- one block of output O: B × d
The block size formula is roughly:

B ≈ M / (4d)

where M is the SRAM size and d is the model dimension. For M = 192 KB and d = 768 (the model width), that gives:

B ≈ 196608 / (4 × 768) = 64

In practice block sizes are kept at multiples of 64 for memory-alignment reasons, so B = 64 in a typical A100 deployment.
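The same arithmetic in code, taking the chapter's numbers at face value (192 KB of SRAM and a model width of 768 are assumptions from this example; real kernels pick block sizes such as 64 or 128 from more detailed occupancy accounting):

```python
M = 192 * 1024       # SRAM per SM in bytes (A100, approximate)
d = 768              # model width used in the chapter's example
B = M // (4 * d)     # four B×d blocks must fit simultaneously
print(B)             # -> 64
```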
21.3.3 The Core Loop
Algorithm: FlashAttention (simplified)
```
Input:  Q, K, V in HBM
Output: O in HBM

for j = 1 to T_c:                    # outer loop over K, V blocks
    load K_j, V_j into SRAM
    for i = 1 to T_r:                # inner loop over Q blocks
        load Q_i, O_i, l_i, m_i from HBM into SRAM
        S_ij = Q_i @ K_j^T           # score block
        update m_i (running max)
        update l_i (running denominator)
        O_i += rescaled_P_ij @ V_j   # accumulate output
        write O_i, l_i, m_i back to HBM
```
The "rescaled" part is the job of online Softmax.
21.4 Online Softmax: Computing Softmax Without Seeing Everything
21.4.1 The Problem
Standard Softmax:

Softmax(x)_i = exp(x_i) / Σ_j exp(x_j)

The denominator sums over all N elements. But during tiling we only see one block at a time. How do we compute Softmax correctly?
21.4.2 The Online Update Rule
We maintain three running quantities:
- m: maximum value seen so far
- p: vector of numerator terms exp(x_i − m) (rescaled exponentials)
- l: sum of numerator terms (denominator accumulator)
When a new block x_new arrives after processing block x_old:
1. Update the running maximum: m_new = max(m_old, max(x_new))
2. Rescale previous numerators: p ← p · exp(m_old − m_new)
3. Update the denominator: l_new = l_old · exp(m_old − m_new) + Σ_i exp(x_new,i − m_new)
4. Final Softmax: Softmax(x)_i = exp(x_i − m_final) / l_final
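The update rule can be checked numerically with a minimal pure-Python sketch (the block split and variable names here are illustrative, not the CUDA implementation):

```python
import math

def online_softmax(blocks):
    """Softmax over the concatenation of `blocks`, seeing one block at a time."""
    m = float('-inf')   # running max
    l = 0.0             # running denominator
    p = []              # numerator terms, relative to the current max
    for block in blocks:
        m_new = max(m, max(block))
        scale = math.exp(m - m_new)                    # correction factor
        p = [x * scale for x in p]                     # rescale old numerators
        p += [math.exp(x - m_new) for x in block]      # append new numerators
        l = l * scale + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return [x / l for x in p]

row = [3.0, 1.0, 2.0, 2.0]
tiled = online_softmax([row[:2], row[2:]])
m_full = max(row)
direct = [math.exp(x - m_full) / sum(math.exp(y - m_full) for y in row) for x in row]
assert all(abs(a - b) < 1e-12 for a, b in zip(tiled, direct))  # exact, not approximate
```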
21.4.3 Why Track the Maximum?
The correction factor exp(m_old − m_new) exists for numerical stability.
Computing exp(x) directly for large x overflows (FP16 overflows once x exceeds roughly 11). The standard fix: subtract the maximum before exponentiating.
In tiled computation each block has its own local maximum. When a new block arrives and the global maximum changes, we use exp(m_old − m_new) to retroactively correct the previous accumulation.
21.4.4 A Worked Example
Full row: . Split into two blocks.
Block 1 — :
Block 2 — :
- New global max: (unchanged)
- Correction factors: for block 1; for block 2
Softmax of first element:
Direct calculation gives 50.28%. The tiny difference is rounding from the worked example — in practice the math is exact.
21.5 The Full FlashAttention Algorithm
Algorithm: FLASHATTENTION
```
Input: Q, K, V ∈ R^{N×d} in HBM; on-chip SRAM of size M

 1. Set B_c = ceil(M / 4d), B_r = min(ceil(M / 4d), d)
 2. Initialize O = 0, l = 0, m = −∞ in HBM
 3. Divide Q into T_r = ceil(N / B_r) blocks
 4. Divide K, V into T_c = ceil(N / B_c) blocks
 5. for j = 1 to T_c:
 6.     Load K_j, V_j from HBM to SRAM
 7.     for i = 1 to T_r:
 8.         Load Q_i, O_i, l_i, m_i from HBM to SRAM
 9.         S_ij = Q_i @ K_j^T
10.         m̃_ij = rowmax(S_ij)
11.         P̃_ij = exp(S_ij − m̃_ij), l̃_ij = rowsum(P̃_ij)
12.         m_i_new = max(m_i, m̃_ij)
13.         l_i_new = exp(m_i − m_i_new) × l_i + exp(m̃_ij − m_i_new) × l̃_ij
14.         O_i = diag(l_i_new)^{-1} × (diag(l_i) × exp(m_i − m_i_new) × O_i
                  + exp(m̃_ij − m_i_new) × P̃_ij × V_j)
15.         Write O_i, l_i_new, m_i_new back to HBM
16. Return O
```
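A minimal NumPy transcription of the algorithm makes the bookkeeping concrete. This is a readability sketch, not a performance implementation: single head, no masking or dropout, and the block sizes are ordinary function parameters rather than values derived from SRAM capacity.

```python
import numpy as np

def flash_attention(Q, K, V, B_r=64, B_c=64):
    N, d = Q.shape
    O = np.zeros((N, d))
    l = np.zeros(N)                   # running softmax denominators, per row
    m = np.full(N, -np.inf)           # running row maxima
    scale = 1.0 / np.sqrt(d)
    for j in range(0, N, B_c):        # outer loop over K, V blocks
        Kj, Vj = K[j:j+B_c], V[j:j+B_c]
        for i in range(0, N, B_r):    # inner loop over Q blocks
            S = (Q[i:i+B_r] @ Kj.T) * scale          # score block S_ij
            m_tilde = S.max(axis=1)                  # local row max
            P = np.exp(S - m_tilde[:, None])
            l_tilde = P.sum(axis=1)                  # local denominator
            m_new = np.maximum(m[i:i+B_r], m_tilde)
            alpha = np.exp(m[i:i+B_r] - m_new)       # rescales old state
            beta = np.exp(m_tilde - m_new)           # rescales new block
            l_new = alpha * l[i:i+B_r] + beta * l_tilde
            # Keep O normalized at every step (line 14 of the algorithm):
            O[i:i+B_r] = ((l[i:i+B_r] * alpha)[:, None] * O[i:i+B_r]
                          + beta[:, None] * (P @ Vj)) / l_new[:, None]
            l[i:i+B_r], m[i:i+B_r] = l_new, m_new
    return O
```

An equivalent variant keeps O unnormalized inside the loop and divides by l once at the end, which saves work; that deferred rescaling is one of FlashAttention-2's refinements.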
21.5.1 IO Complexity
Standard Attention:
- Write S = QKᵀ: Θ(N²)
- Read S, compute Softmax, write P: Θ(N²)
- Read P and V, write output: Θ(N² + Nd)
- Total HBM traffic: Θ(Nd + N²)
Flash Attention:
- Each K/V block is read once per outer loop: Θ(Nd) in total
- Each Q/O block is read and written once per K/V block: Θ(T_c · Nd) = Θ(N²d²/M)
- Total HBM traffic: Θ(N²d²/M)
Since the SRAM size M is much larger than d², Flash Attention's IO complexity Θ(N²d²/M) sits far below the standard Θ(N²): a factor of roughly M/d² improvement over the standard path.
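Plugging illustrative numbers into these asymptotic counts (the values of N, d, and M below are assumptions, with M measured in elements; constant factors are ignored, so measured gains are smaller):

```python
N, d, M = 4096, 64, 100_000    # sequence length, head dim, SRAM size in elements

standard_io = N * d + N * N          # Theta(Nd + N^2) element accesses
flash_io = N * N * d * d // M        # Theta(N^2 d^2 / M) element accesses

print(standard_io / flash_io)        # about 25x less HBM traffic in this estimate
```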
21.6 Flash Attention 1 vs Flash Attention 2
21.6.1 What FA1 Did
Flash Attention 1 (2022) proved the idea and delivered real speedups. Its inner loop parallelism was limited by synchronization between workers sharing the output accumulator.
21.6.2 What FA2 Added
Flash Attention 2 (2023) made three important changes:
- Better work partitioning — reduces synchronization between streaming multiprocessors and uses the hardware more uniformly
- Native MQA and GQA support — directly handles the head-sharing patterns covered in Chapter 23
- Fewer non-matmul operations — fewer register spills, cleaner pipeline
Performance on A100 80GB SXM4:
| Config | PyTorch | FA1 | FA2 |
|---|---|---|---|
| Seq 2k, head_dim 64 | ~50 TFLOPS | ~120 TFLOPS | ~175 TFLOPS |
| Seq 4k, head_dim 64 | ~45 TFLOPS | ~110 TFLOPS | ~170 TFLOPS |
| Seq 8k, head_dim 128 | ~40 TFLOPS | ~100 TFLOPS | ~165 TFLOPS |
FA2 reaches 50–70% of the A100's peak throughput on a memory-bound operation. That is very good.
21.7 Practical Usage
21.7.1 Installation
```shell
pip install flash-attn --no-build-isolation
```
21.7.2 Direct API
```python
import torch
from flash_attn import flash_attn_func

batch_size, seq_len, num_heads, head_dim = 2, 4096, 32, 128

# shape: [batch, seq_len, num_heads, head_dim]
q = torch.randn(batch_size, seq_len, num_heads, head_dim,
                dtype=torch.float16, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)

output = flash_attn_func(q, k, v, causal=True)
```
21.7.3 PyTorch 2.0+ Built-in
PyTorch 2.0 added scaled_dot_product_attention, which dispatches to Flash Attention automatically when the inputs qualify:
```python
import torch.nn.functional as F

output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True,
)
# PyTorch picks Flash Attention, Memory Efficient Attention,
# or the standard path depending on hardware and input shape.
```
21.7.4 Hugging Face Transformers
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
```
21.7.5 Limitations Worth Knowing
Hardware: Flash Attention requires recent NVIDIA GPUs. Ampere (A100) and newer get the best results. Older cards or non-NVIDIA hardware fall back to standard paths.
Backward pass: Flash Attention does not store the full score matrix, so the backward pass has to recompute it. The extra arithmetic is cheap compared to the IO savings; end-to-end training still wins by 2–4×.
Non-standard masks: Custom Attention masks (sparse, sliding-window, arbitrary patterns) may need special handling. FA2 already supports causal masks, padding masks, and MQA/GQA out of the box.
21.8 Chapter Summary
| Concept | Key Point |
|---|---|
| Bottleneck | HBM bandwidth, not arithmetic throughput |
| SRAM vs HBM | SRAM ~19 TB/s; HBM ~1.5–3 TB/s; SRAM is roughly an order of magnitude faster |
| Tiling | Process small Q/K/V blocks in SRAM, never write the full N×N matrix |
| Online Softmax | Track running max, numerator, denominator; correct past blocks when global max updates |
| Memory complexity | Standard: O(N²); Flash: O(N) |
| FA1 → FA2 | Better parallelism, native MQA/GQA, 1.5–2× over FA1 |
| End-to-end speedup | 2–4× training; 5× on attention kernel alone |
Chapter Checklist
After this chapter, you should be able to:
- Explain why HBM bandwidth, not TFLOPS, is the Attention bottleneck.
- Describe what tiling does and why it avoids materializing the N×N matrix.
- Walk through the Online Softmax update rule.
- State the memory complexity of standard vs Flash Attention.
- Explain why Flash Attention is exact, not approximate.
- Compare FA1 and FA2.
Further Reading
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022) — arXiv 2205.14135
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (Dao, 2023) — arXiv 2307.08691
- Memory Efficient Attention (xFormers)
- PagedAttention (vLLM) — a complementary approach for KV Cache management
See You in the Next Chapter
Flash Attention makes each individual Attention computation cheaper. But during autoregressive generation, we still have a different problem: the model recomputes K and V for old tokens on every single step.
Chapter 22 fixes that with KV Cache, which is Flash Attention's natural partner for fast inference.