One-sentence summary: Flash Attention is exact Attention reimplemented to avoid writing the full N×N score matrix to GPU main memory — trading a little extra arithmetic for a lot less memory traffic.


21.1 Why Standard Attention Gets Slow

21.1.1 A Confusing Observation

You are training a Transformer on an NVIDIA A100 GPU, which has a theoretical throughput of 312 TFLOPS (FP16). By that number, it should fly. But as sequence length grows, training slows down sharply and OOM errors become common.

Stranger still: even when GPU memory has headroom, even when GPU utilization looks fine, Attention becomes the bottleneck.

The answer is hiding in a place most people overlook: memory bandwidth.

21.1.2 The Attention Memory Problem

Standard Attention computes:

\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V

For sequence length N and per-token dimension d:

  • Q shape: [N, d]
  • K shape: [N, d]
  • QK^T shape: [N, N]

When N = 4096, that score matrix contains 4096 × 4096 = 16,777,216 elements. In FP16 that is about 32 MB — for a single head, single sample.

Scale it up to a realistic training run (32 heads, batch size 8, forward + backward):

32 \times 8 \times 32\ \text{MB} = 8\ \text{GB}

That is 8 GB just to store the Attention score matrices. And every training step has to move all of it through GPU main memory.
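The arithmetic above is easy to sanity-check in a few lines of Python:

```python
# Sanity check of the 32 MB / 8 GB figures above
N, heads, batch = 4096, 32, 8
bytes_per_el = 2                          # FP16
score_matrix = N * N * bytes_per_el       # one head, one sample
total = score_matrix * heads * batch      # 32 heads x batch 8
print(score_matrix / 2**20, "MB")         # 32.0 MB
print(total / 2**30, "GB")                # 8.0 GB
```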

QK matrix visualized alongside the GPU memory hierarchy

21.1.3 GPU Memory Hierarchy

GPU memory is not a single flat pool. It has a hierarchy:

| Level   | Name                        | Capacity  | Bandwidth   | Notes                 |
|---------|-----------------------------|-----------|-------------|-----------------------|
| On-chip | SRAM (L1/L2/shared)         | ~20 MB    | ~19 TB/s    | Very fast, very small |
| Device  | HBM (high bandwidth memory) | ~40–80 GB | ~1.5–3 TB/s | GPU main memory       |
| Host    | CPU DRAM                    | ~1 TB     | ~12.8 GB/s  | Large, much slower    |

At ~19 TB/s versus ~1.5–3 TB/s, SRAM is roughly an order of magnitude faster than HBM.

Think of your desk (SRAM) versus a bookshelf across the room (HBM). You can work on whatever is on your desk instantly. To reach the shelf, you have to walk over, grab the book, bring it back.

Standard Attention works like this:

  1. Read Q, K from HBM → compute QK^T → write it back to HBM
  2. Read QK^T from HBM → compute Softmax → write the result back to HBM
  3. Read the Softmax result and V from HBM → compute the output → write it back to HBM

Every round trip through HBM burns time. That is the real bottleneck.


21.2 Standard vs Flash: The Numbers

21.2.1 What Standard PyTorch Does

# Standard Attention
import math

import torch
import torch.nn.functional as F

def standard_attention(Q, K, V, dropout_p=0.0):
    d_k = Q.size(-1)

    # Step 1: compute QK^T; the full [N, N] score matrix lands in HBM
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Step 2: read scores from HBM, compute softmax, write back to HBM
    attention_weights = torch.softmax(scores, dim=-1)

    # Step 3: optional dropout — another HBM round trip
    attention_weights = F.dropout(attention_weights, p=dropout_p)

    # Step 4: read weights and V from HBM, compute output
    output = torch.matmul(attention_weights, V)
    return output
PyTorch vs FlashAttention throughput benchmark

On GPT-2-sized Attention:

  • PyTorch standard: ~15 ms, split across Matmul, Dropout, Softmax, Mask as separate kernels
  • Flash Attention: ~3 ms, fusing everything into one fused kernel

That is a 5× speedup just from eliminating HBM round trips.

21.2.2 Memory Complexity

The memory picture is even cleaner:

| Implementation     | Memory for the Attention matrix                        |
|--------------------|--------------------------------------------------------|
| Standard Attention | O(N²) — full score matrix in HBM                       |
| Flash Attention    | O(N) — only inputs and outputs, no intermediate matrix |

At N = 2048 Flash Attention cuts the intermediate memory by a factor of roughly 2048. At N = 4096 that factor doubles again.


21.3 Tiling: The Core Idea

21.3.1 The Intuition

Imagine you need to compute a very large multiplication table — say 100×100 entries. The naive approach:

  1. Compute the whole table, write it on a big sheet.
  2. Post-process each row (Softmax).
  3. Continue the calculation.

Flash Attention does this instead:

  1. Cut the big table into 10×10 tiles.
  2. Process one tile at a time, entirely on the scratchpad (SRAM).
  3. Accumulate the final result without ever writing the full table.
SRAM tiling: small blocks of K, Q, V are loaded and processed entirely in SRAM

The sequence of operations inside SRAM for each tile:

1. Q_block @ K_block^T
2. apply causal mask
3. online Softmax (see Section 21.4)
4. optional dropout
5. multiply by V_block
6. accumulate into output O_i

Only the final output O needs to go back to HBM. The giant intermediate N × N matrix never exists in memory at all.

21.3.2 How Big Are the Tiles?

On an A100, SRAM per streaming multiprocessor is about 192 KB. We need to fit four things simultaneously:

  • one block of Q: B_r × d
  • one block of K: B_c × d
  • one block of V: B_c × d
  • one block of output: B_r × d

The block size formula:

\min(B_r, B_c) = \left\lceil\frac{M}{4d}\right\rceil

where M is the SRAM size and d is the model dimension. For M = 192 KB = 192 × 1024 bytes and d = 512 (the model width), that gives:

\left\lceil\frac{192 \times 1024}{4 \times 512}\right\rceil = \left\lceil\frac{196{,}608}{2{,}048}\right\rceil = 96

In practice this is rounded down to 64 for memory-alignment reasons, so B_r = B_c = 64 in a typical A100 deployment.
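Plugging other widths into the same ⌈M/4d⌉ formula shows how the block size shrinks as d grows (a quick sketch reusing the text's M = 192 KB; the d values are illustrative):

```python
import math

M = 192 * 1024                  # SRAM budget from the text
for d in (64, 128, 512):
    print(d, math.ceil(M / (4 * d)))
# d = 64 -> 768, d = 128 -> 384, d = 512 -> 96
```

Smaller per-head dimensions leave room for much larger tiles, which is one reason head_dim 64 and 128 are such common choices.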

A100 streaming multiprocessors and block size derivation

21.3.3 The Core Loop

Algorithm: FlashAttention (simplified)

Input:  Q, K, V in HBM
Output: O in HBM

for j = 1 to T_c:          # outer loop over K, V blocks
    load K_j, V_j into SRAM

    for i = 1 to T_r:      # inner loop over Q blocks
        load Q_i, O_i, l_i, m_i from HBM into SRAM

        S_ij = Q_i @ K_j^T             # score block
        update m_i (running max)
        update l_i (running denominator)
        O_i += rescaled_P_ij @ V_j     # accumulate output

        write O_i, l_i, m_i back to HBM

The "rescaled" part is the job of online Softmax.
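The whole loop can be sketched in pure NumPy (a teaching sketch, not a production kernel; the real implementation is a single fused CUDA kernel, and the rescaling steps are the online-softmax updates detailed in Section 21.4):

```python
import numpy as np

def tiled_attention(Q, K, V, block=32):
    """Tiled attention with an online softmax; working set is O(block * d)."""
    N, d = Q.shape
    O = np.zeros((N, d))         # output accumulator (kept normalized)
    l = np.zeros(N)              # running softmax denominators
    m = np.full(N, -np.inf)      # running row maxima
    for j in range(0, N, block):             # outer loop over K/V blocks
        Kj, Vj = K[j:j + block], V[j:j + block]
        for i in range(0, N, block):         # inner loop over Q blocks
            r = slice(i, i + block)
            S = Q[r] @ Kj.T / np.sqrt(d)             # score block S_ij
            m_new = np.maximum(m[r], S.max(axis=1))  # updated running max
            P = np.exp(S - m_new[:, None])           # rescaled exponentials
            l_new = np.exp(m[r] - m_new) * l[r] + P.sum(axis=1)
            # un-normalize the old output, rescale it, add the new block, renormalize
            O[r] = ((l[r] * np.exp(m[r] - m_new))[:, None] * O[r]
                    + P @ Vj) / l_new[:, None]
            m[r], l[r] = m_new, l_new
    return O
```

Checked against a dense softmax reference on random inputs, the outputs match to machine precision, which is the point: tiling changes the memory access pattern, not the math.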


21.4 Online Softmax: Computing Softmax Without Seeing Everything

21.4.1 The Problem

Standard Softmax:

\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{K} e^{x_j}}

Online Softmax tracks running max and denominator across blocks

The denominator sums over all elements. But during tiling we only see one block at a time. How do we compute Softmax correctly?

21.4.2 The Online Update Rule

We maintain three running quantities:

  • m(x): maximum value seen so far
  • f(x): vector of numerator terms (rescaled exponentials)
  • l(x): sum of numerator terms (denominator accumulator)

When a new block x^{(2)} arrives after processing block x^{(1)}:

1. Update the running maximum:

m(x) = \max\!\bigl(m(x^{(1)}),\, m(x^{(2)})\bigr)

2. Rescale previous numerators:

f(x) = \Bigl[e^{m(x^{(1)}) - m(x)} \cdot f(x^{(1)}),\;\; e^{m(x^{(2)}) - m(x)} \cdot f(x^{(2)})\Bigr]

3. Update the denominator:

l(x) = e^{m(x^{(1)}) - m(x)} \cdot l(x^{(1)}) + e^{m(x^{(2)}) - m(x)} \cdot l(x^{(2)})

4. Final Softmax:

\text{softmax}(x) = \frac{f(x)}{l(x)}

21.4.3 Why Track the Maximum?

The correction factor e^{m_old − m_new} is there for numerical stability.

Computing exe^x for large xx overflows. The standard fix: subtract the maximum before exponentiating.

\text{softmax}(x)_i = \frac{e^{x_i - \max(x)}}{\sum_{j} e^{x_j - \max(x)}}

In tiled computation each block has its own local maximum. When a new block arrives and the global maximum changes, we use e^{m_old − m_new} to retroactively correct the previous accumulation.
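A two-line experiment shows why the max-subtraction matters:

```python
import numpy as np

x = np.array([1000.0, 1001.0])

with np.errstate(over='ignore', invalid='ignore'):
    naive = np.exp(x) / np.exp(x).sum()   # exp overflows: inf / inf = nan

stable = np.exp(x - x.max())              # subtract the max first
stable /= stable.sum()

print(naive)    # [nan nan]
print(stable)   # [0.26894142 0.73105858]
```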

21.4.4 A Worked Example

Full row: [3.01, 0.09, 2.48, 1.95]. Split into two blocks.

Block 1 — [3.01, 0.09]:

  • m^{(1)} = 3.01
  • f^{(1)} = [e^{0}, e^{-2.92}] = [1, 0.053]
  • l^{(1)} = 1.053

Block 2 — [2.48, 1.95]:

  • m^{(2)} = 2.48
  • New global max: m = max(3.01, 2.48) = 3.01 (unchanged)
  • Correction factors: e^{3.01 - 3.01} = 1 for block 1; e^{2.48 - 3.01} ≈ 0.59 for block 2
  • l = 1 × 1.053 + 0.59 × (1 + e^{-0.53}) ≈ 1.99

Softmax of first element:

\text{softmax}(3.01) = \frac{1}{1.99} \approx 50.25\%

Direct calculation gives 50.28%. The tiny difference is rounding from the worked example — in practice the math is exact.
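Running the same bookkeeping in code confirms that the blockwise and direct paths agree to machine precision:

```python
import numpy as np

x = np.array([3.01, 0.09, 2.48, 1.95])
b1, b2 = x[:2], x[2:]

# block-local statistics
m1, l1 = b1.max(), np.exp(b1 - b1.max()).sum()
m2, l2 = b2.max(), np.exp(b2 - b2.max()).sum()

# merge: rescale both partial denominators to the global max
m = max(m1, m2)
l = np.exp(m1 - m) * l1 + np.exp(m2 - m) * l2

online = np.exp(x[0] - m) / l
direct = np.exp(x - x.max())[0] / np.exp(x - x.max()).sum()
print(online, direct)   # both ~ 0.5028
```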

Block-by-block accumulation visualization

21.5 The Full FlashAttention Algorithm

Algorithm: FLASHATTENTION

Input:  Q, K, V ∈ R^{N×d} in HBM; on-chip SRAM of size M

1.  Set B_c = ceil(M / 4d),  B_r = min(ceil(M / 4d), d)
2.  Initialize O = 0, l = 0, m = -∞ in HBM

3.  Divide Q into T_r = ceil(N / B_r) blocks
4.  Divide K, V into T_c = ceil(N / B_c) blocks

5.  for j = 1 to T_c:
6.      Load K_j, V_j from HBM to SRAM
7.      for i = 1 to T_r:
8.          Load Q_i, O_i, l_i, m_i from HBM to SRAM
9.          S_ij  = Q_i @ K_j^T
10.         m̃_ij  = rowmax(S_ij)
11.         P̃_ij  = exp(S_ij - m̃_ij),  l̃_ij = rowsum(P̃_ij)
12.         m_i_new = max(m_i, m̃_ij)
13.         l_i_new = exp(m_i - m_i_new) × l_i + exp(m̃_ij - m_i_new) × l̃_ij
14.         O_i     = diag(l_i_new)^{-1} × (diag(l_i) × exp(m_i - m_i_new) × O_i
                         + exp(m̃_ij - m_i_new) × P̃_ij × V_j)
15.         Write O_i, l_i_new, m_i_new back to HBM
16. Return O
FlashAttention algorithm annotated with SRAM and HBM data flow

21.5.1 IO Complexity

Standard Attention:

  • Write S = QK^T: O(N²)
  • Read it back, compute Softmax, write the result: O(N²)
  • Read the Softmax result and V, write the output: O(N² + Nd)
  • Total HBM traffic: O(N² + Nd)

Flash Attention:

  • Each K/V block is read once per outer loop: T_c reads of O(B_c d) each
  • Each Q/O block is read and written once per inner iteration: T_r × T_c transfers of O(B_r d) each
  • Total HBM traffic: O(N²d²/M)

Since M is much larger than d² in typical settings, O(N²d²/M) is well below the standard path's O(N²): the HBM traffic shrinks by a factor of roughly M/d².
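Plugging in illustrative numbers gives a feel for the gap (the M here, an SRAM budget of 100,000 elements, is an assumption for this sketch, not an A100 spec):

```python
N, d, M = 4096, 64, 100_000        # M: assumed SRAM budget in elements

standard_io = N * N + N * d        # O(N^2 + Nd) HBM traffic
flash_io = N * N * d * d / M       # O(N^2 d^2 / M) HBM traffic

print(standard_io / flash_io)      # ~ 24.8, i.e. roughly M / d^2
```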


21.6 Flash Attention 1 vs Flash Attention 2

FA1 vs FA2 parallelism strategy and throughput comparison

21.6.1 What FA1 Did

Flash Attention 1 (2022) proved the idea and delivered real speedups. Its inner loop parallelism was limited by synchronization between workers sharing the output accumulator.

21.6.2 What FA2 Added

Flash Attention 2 (2023) made three important changes:

  1. Better work partitioning — reduces synchronization between streaming multiprocessors and uses the hardware more uniformly
  2. Native MQA and GQA support — directly handles the head-sharing patterns covered in Chapter 23
  3. Fewer non-matmul operations — fewer register spills, cleaner pipeline

Performance on A100 80GB SXM4:

| Config               | PyTorch    | FA1         | FA2         |
|----------------------|------------|-------------|-------------|
| Seq 2k, head_dim 64  | ~50 TFLOPS | ~120 TFLOPS | ~175 TFLOPS |
| Seq 4k, head_dim 64  | ~45 TFLOPS | ~110 TFLOPS | ~170 TFLOPS |
| Seq 8k, head_dim 128 | ~40 TFLOPS | ~100 TFLOPS | ~165 TFLOPS |

FA2 reaches 50–70% of the A100's peak throughput on a memory-bound operation. That is very good.


21.7 Practical Usage

21.7.1 Installation

pip install flash-attn --no-build-isolation

21.7.2 Direct API

import torch
from flash_attn import flash_attn_func

batch_size, seq_len, num_heads, head_dim = 2, 4096, 32, 128

# shape: [batch, seq_len, num_heads, head_dim]
q = torch.randn(batch_size, seq_len, num_heads, head_dim,
                dtype=torch.float16, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)

output = flash_attn_func(q, k, v, causal=True)

21.7.3 PyTorch 2.0+ Built-in

PyTorch 2.0 added scaled_dot_product_attention, which dispatches to Flash Attention automatically when the inputs qualify:

import torch.nn.functional as F

output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True,
)
# PyTorch picks Flash Attention, Memory Efficient Attention,
# or the standard path depending on hardware and input shape.

21.7.4 Hugging Face Transformers

from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)

21.7.5 Limitations Worth Knowing

Hardware: Flash Attention requires recent NVIDIA GPUs. Ampere (A100) and newer get the best results. Older cards or non-NVIDIA hardware fall back to standard paths.

Backward pass: Flash Attention does not store the full score matrix, so the backward pass has to recompute it. The extra arithmetic is cheap compared to the IO savings; end-to-end training still wins by 2–4×.

Non-standard masks: Custom Attention masks (sparse, sliding-window, arbitrary patterns) may need special handling. FA2 already supports causal masks, padding masks, and MQA/GQA out of the box.


21.8 Chapter Summary

| Concept            | Key Point                                                                                  |
|--------------------|--------------------------------------------------------------------------------------------|
| Bottleneck         | HBM bandwidth, not arithmetic throughput                                                   |
| SRAM vs HBM        | SRAM ~19 TB/s; HBM ~1.5–3 TB/s; SRAM is roughly an order of magnitude faster               |
| Tiling             | Process small Q/K/V blocks in SRAM, never write the full N×N matrix                        |
| Online Softmax     | Track running max, numerator, denominator; correct past blocks when the global max updates |
| Memory complexity  | Standard: O(N²); Flash: O(N)                                                               |
| FA1 → FA2          | Better parallelism, native MQA/GQA, 1.5–2× over FA1                                        |
| End-to-end speedup | 2–4× training; 5× on the attention kernel alone                                            |

Chapter Checklist

After this chapter, you should be able to:

  • Explain why HBM bandwidth, not TFLOPS, is the Attention bottleneck.
  • Describe what tiling does and why it avoids materializing the N×N matrix.
  • Walk through the Online Softmax update rule.
  • State the memory complexity of standard vs Flash Attention.
  • Explain why Flash Attention is exact, not approximate.
  • Compare FA1 and FA2.

Further Reading

  • FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022) — arXiv 2205.14135
  • FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (Dao, 2023) — arXiv 2307.08691
  • Memory Efficient Attention (xFormers)
  • PagedAttention (vLLM) — a complementary approach for KV Cache management

See You in the Next Chapter

Flash Attention makes each individual Attention computation cheaper. But during autoregressive generation, we still have a different problem: the model recomputes K and V for old tokens on every single step.

Chapter 22 fixes that with KV Cache, which is Flash Attention's natural partner for fast inference.

Cite this page
Zhang, Wayland (2026). Chapter 21: Flash Attention - Memory-Aware Attention. In Transformer Architecture: From Intuition to Implementation. https://waylandz.com/llm-transformer-book-en/chapter-21-flash-attention
@incollection{zhang2026transformer_chapter_21_flash_attention,
  author = {Zhang, Wayland},
  title = {Chapter 21: Flash Attention - Memory-Aware Attention},
  booktitle = {Transformer Architecture: From Intuition to Implementation},
  year = {2026},
  url = {https://waylandz.com/llm-transformer-book-en/chapter-21-flash-attention}
}