Flash Attention: How One Paper Made Long Context Possible
The Kitchen Analogy
Imagine a chef preparing a banquet for 1,000 guests. The standard approach: bring every ingredient from the warehouse to the kitchen counter, prepare all dishes, then serve. The counter (GPU SRAM) is tiny — most ingredients sit in the warehouse (GPU HBM), and the chef spends most of their time running back and forth.
Flash Attention’s approach: bring ingredients in batches. Prepare one table’s dishes at a time. Return ingredients, get the next batch. The chef never needs a massive counter — they just work in tiles.
The key insight of Flash Attention isn’t doing less computation. It’s doing less memory movement. The compute is identical. The IO is dramatically reduced.
GPU Memory Hierarchy
Modern GPUs have two main memory tiers:
| Memory | Size | Bandwidth | Latency |
|---|---|---|---|
| SRAM (on-chip) | 20–40 MB | ~19 TB/s | ~ns |
| HBM (off-chip) | 80–141 GB | ~3 TB/s | ~100ns |
SRAM is 6× faster but 4,000× smaller. Most GPU operations are bottlenecked by how fast they can move data between HBM and SRAM — this is called being memory-bound or IO-bound.
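To make "memory-bound" concrete, here is a rough back-of-the-envelope sketch that uses the approximate bandwidths from the table. The sequence length (`n = 2048`), FP16 storage, and the helper name `transfer_time_ms` are illustrative assumptions, and real kernels rarely reach peak bandwidth.

```python
# Rough estimate of how long it takes just to move an n×n FP16 matrix
# through each memory tier, using the approximate bandwidths from the table.
HBM_BANDWIDTH = 3e12    # bytes/s (~3 TB/s)
SRAM_BANDWIDTH = 19e12  # bytes/s (~19 TB/s)


def transfer_time_ms(num_bytes, bandwidth_bytes_per_s):
    """Time in milliseconds to move num_bytes at the given bandwidth."""
    return num_bytes / bandwidth_bytes_per_s * 1e3


n = 2048                  # assumed sequence length, for illustration only
matrix_bytes = n * n * 2  # FP16 = 2 bytes per element

hbm_ms = transfer_time_ms(matrix_bytes, HBM_BANDWIDTH)
sram_ms = transfer_time_ms(matrix_bytes, SRAM_BANDWIDTH)

print(f"matrix size:        {matrix_bytes / 1e6:.1f} MB")
print(f"one pass over HBM:  {hbm_ms:.5f} ms")
print(f"one pass over SRAM: {sram_ms:.5f} ms")
```

Every extra read or write of that matrix pays the HBM cost again, which is why cutting the number of passes matters more than cutting FLOPs.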
Standard Attention: The Memory Problem
Standard attention computes:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V$$
The naive implementation:
```python
import torch


def standard_attention(Q, K, V):
    """
    Standard attention: materializes the full n×n matrix.
    Memory: O(n²) for the attention matrix
    HBM reads/writes: O(nd + n²), dominated by n²
    """
    d = Q.shape[-1]
    # Step 1: compute S = QK^T / sqrt(d), writing an n×n matrix to HBM
    S = Q @ K.T / (d ** 0.5)      # Read Q, K from HBM: O(nd)
                                  # Write S to HBM: O(n²)
    # Step 2: softmax reads S from HBM and writes A back to HBM
    A = torch.softmax(S, dim=-1)  # Read S: O(n²), write A: O(n²)
    # Step 3: multiply by V, reading A from HBM
    output = A @ V                # Read A: O(n²), read V: O(nd)
                                  # Write output: O(nd)
    return output

# Total HBM reads/writes: O(n²), QUADRATIC in sequence length
```
For a sequence of length $n$ in FP16 (2 bytes per element), the attention matrix alone takes $2n^2$ bytes. At a 128K context, the case used in the comparison table later in this article, that is roughly 30 GB.
A single matrix of that size exceeds the SRAM capacity by roughly 1,000×, forcing constant HBM access.
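The same arithmetic can be run for a few sequence lengths. This is a minimal sketch: the helper name `attention_matrix_bytes` and the chosen lengths are illustrative, and the figure is for one attention matrix, so it multiplies again by the number of heads and the batch size.

```python
def attention_matrix_bytes(n, bytes_per_element=2):
    """Bytes needed to store one n×n attention matrix (FP16 by default)."""
    return n * n * bytes_per_element


for n in (4_096, 32_768, 131_072):
    gb = attention_matrix_bytes(n) / 1e9
    print(f"n = {n:>7,}: {gb:8.2f} GB for one attention matrix")
```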
Flash Attention: The IO-Aware Solution
Flash Attention (Dao et al., 2022) never materializes the full n×n attention matrix in HBM. Instead, it computes attention in tiles that fit in SRAM:
The Tiling Strategy
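Divide Q into blocks of size $B_r$ and K, V into blocks of size $B_c$:
$$Q = [Q_1; Q_2; \dots; Q_{T_r}], \quad K = [K_1; K_2; \dots; K_{T_c}], \quad V = [V_1; V_2; \dots; V_{T_c}]$$
where $T_r = \lceil n / B_r \rceil$ and $T_c = \lceil n / B_c \rceil$.
For each pair $(Q_i, K_j)$, compute a small $B_r \times B_c$ attention tile entirely in SRAM.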
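The tiling can be sketched with plain index ranges. The helper `tile_ranges` and the sequence length are hypothetical, chosen so the last block is ragged, and the block sizes match the defaults in the pseudocode later in this section.

```python
def tile_ranges(n, block_size):
    """Yield (start, stop) row ranges covering 0..n in chunks of block_size."""
    for start in range(0, n, block_size):
        yield start, min(start + block_size, n)


n = 1000             # assumed sequence length, not a multiple of the block size
B_r, B_c = 128, 128  # row (Q) and column (K, V) block sizes

q_blocks = list(tile_ranges(n, B_r))
kv_blocks = list(tile_ranges(n, B_c))

# Every (Q block, K/V block) pair is one tile of the n×n score matrix.
tiles = [(qi, kj) for kj in kv_blocks for qi in q_blocks]

# Largest number of scores that has to be held at once for a single tile.
largest = max((q1 - q0) * (k1 - k0) for (q0, q1), (k0, k1) in tiles)

print(f"{len(q_blocks)} Q blocks x {len(kv_blocks)} K/V blocks = {len(tiles)} tiles")
print(f"largest tile holds {largest} scores, versus {n * n} for the full matrix")
```

Only one tile's worth of scores needs to exist at any moment, which is what lets the working set fit in SRAM.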
The Online Softmax Trick
The challenge: softmax needs the global maximum across all keys, but we’re processing one tile at a time. How do we normalize correctly?
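The solution is online softmax, which computes softmax incrementally. For each new tile of scores $S_{ij}$ and the matching value block $V_j$:
$$m^{\text{new}} = \max\left(m, \text{rowmax}(S_{ij})\right)$$
$$\ell^{\text{new}} = e^{m - m^{\text{new}}} \, \ell + \text{rowsum}\left(e^{S_{ij} - m^{\text{new}}}\right)$$
$$O^{\text{new}} = \frac{e^{m - m^{\text{new}}} \, \ell \, O + e^{S_{ij} - m^{\text{new}}} V_j}{\ell^{\text{new}}}$$
This maintains running statistics ($m$ for the max, $\ell$ for the sum) that allow exact softmax computation without seeing all keys at once.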
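The running-statistics idea can be checked on a single row of scores. This is a minimal pure-Python sketch rather than the kernel's actual code: the function names, block size, and sample scores are illustrative assumptions.

```python
import math


def online_softmax(scores, block_size=4):
    """Softmax over `scores`, reading only `block_size` values at a time."""
    m = float("-inf")  # running max
    l = 0.0            # running sum of exp(score - m)
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old sum to the new max, then add this block's terms.
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in block)
        m = m_new
    # This second pass only exists to return the probabilities; Flash
    # Attention folds the same rescaling into its output accumulator instead.
    return [math.exp(s - m) / l for s in scores]


def reference_softmax(scores):
    """Ordinary numerically stable softmax that sees all scores at once."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25, 7.0, -2.5, 4.0]
blockwise = online_softmax(scores)
full = reference_softmax(scores)
max_diff = max(abs(a - b) for a, b in zip(blockwise, full))
print(f"largest difference from the reference softmax: {max_diff:.2e}")
```

The two results agree up to floating-point rounding, which is the sense in which Flash Attention is exact rather than approximate.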
Pseudocode
```python
import math

import torch


def flash_attention(Q, K, V, block_size_r=128, block_size_c=128):
    """
    Flash Attention: IO-aware exact attention (reference sketch).
    Key: never materializes the n×n attention matrix.
    Computes in SRAM-sized tiles with online softmax.
    Memory: O(n), no n×n matrix
    HBM IO: O(n²d²/M) where M = SRAM size
    Compute: O(n²d), same as standard attention

    This version runs as ordinary PyTorch; the SRAM comments mark what a
    fused GPU kernel would keep on-chip.
    """
    n, d = Q.shape
    O = torch.zeros(n, d, dtype=Q.dtype, device=Q.device)  # Output accumulator
    m = torch.full((n,), float("-inf"), dtype=Q.dtype, device=Q.device)  # Running max
    l = torch.zeros(n, dtype=Q.dtype, device=Q.device)  # Running softmax denominator

    # Tile over K, V blocks (outer loop)
    for j in range(0, n, block_size_c):
        Kj = K[j:j + block_size_c]  # Load tile from HBM to SRAM
        Vj = V[j:j + block_size_c]

        # Tile over Q blocks (inner loop)
        for i in range(0, n, block_size_r):
            rows = slice(i, i + block_size_r)
            Qi = Q[rows]  # Load tile from HBM to SRAM

            # --- Everything below happens in SRAM ---
            # Compute local attention scores, shape (Br × Bc)
            Sij = Qi @ Kj.T / math.sqrt(d)

            # Online softmax update (Tensor.max with dim returns values and indices)
            m_new = torch.maximum(m[rows], Sij.max(dim=-1).values)
            # Correction factor for the previous accumulation
            correction = torch.exp(m[rows] - m_new)
            # New exponentials
            P = torch.exp(Sij - m_new.unsqueeze(-1))
            # Update running sum
            l_new = correction * l[rows] + P.sum(dim=-1)

            # Update output: rescale old output, add new contribution, renormalize
            O[rows] = (
                (correction * l[rows]).unsqueeze(-1) * O[rows] + P @ Vj
            ) / l_new.unsqueeze(-1)

            # Update statistics
            m[rows] = m_new
            l[rows] = l_new
            # --- End SRAM computation ---
    return O
```
The IO Analysis
Standard Attention
$$\text{HBM accesses} = \Theta(nd + n^2)$$
The $n^2$ term dominates: reading and writing the full attention matrix.
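Flash Attention
$$\text{HBM accesses} = \Theta\left(\frac{n^2 d^2}{M}\right)$$
where $M$ is the SRAM size. Since $d^2 \ll M$ (SRAM is ~20MB, while $d^2$ is typically $128^2 = 16{,}384$ elements = ~32KB in FP16):
$$\frac{n^2 d^2}{M} \ll n^2$$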
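Dividing the standard bound of $n^2$ by the Flash Attention bound of $n^2 d^2 / M$ gives a reduction in HBM accesses of about $M / d^2$.
The reduction grows with $M / d^2$: more SRAM or a smaller head dimension means fewer trips to HBM. In practice, Flash Attention achieves 2–4× speedup over standard attention.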
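The two bounds can be compared numerically. This is a sketch under stated assumptions: constant factors are dropped, the function names are hypothetical, and $M$ is taken from this article's ~20MB SRAM figure in FP16 elements. The Flash Attention bound only applies when $d \le M \le nd$, so the example uses a long sequence; real kernels work with a much smaller per-block SRAM budget, which is one reason measured speedups are far below this ratio.

```python
def standard_hbm_accesses(n, d):
    """Element reads/writes for standard attention: Θ(nd + n²), constants dropped."""
    return n * d + n * n


def flash_hbm_accesses(n, d, sram_elements):
    """Element reads/writes for Flash Attention: Θ(n²d²/M), constants dropped.

    Only meaningful when d <= sram_elements <= n * d.
    """
    return n * n * d * d / sram_elements


n, d = 131_072, 128        # assumed sequence length and head dimension
sram_elements = 20e6 / 2   # ~20 MB of SRAM holding 2-byte FP16 values

standard = standard_hbm_accesses(n, d)
flash = flash_hbm_accesses(n, d, sram_elements)
ratio = standard / flash

print(f"standard: {standard:.3e} element accesses")
print(f"flash:    {flash:.3e} element accesses")
print(f"ratio:    {ratio:.0f}x (compare M / d^2 = {sram_elements / d**2:.0f})")
```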
Flash Attention 2 and 3
Flash Attention 2 (2023)
Key improvements:
- Better work partitioning between GPU thread blocks
- Reduced non-matmul FLOPs
- Achieves 50–73% of theoretical max FLOPS (vs. 25–40% for FA1)
Flash Attention 3 (2024)
Targets Hopper GPUs (H100):
- Uses asynchronous memory operations (TMA)
- Leverages FP8 tensor cores for 2× throughput
- Block quantization and incoherent processing to reduce FP8 quantization error
- Achieves up to 740 TFLOPS in FP16 on H100 (75% utilization)
Impact on Context Length
Flash Attention’s memory savings directly enabled longer context:
| Without Flash Attention | With Flash Attention |
|---|---|
| Attention matrix stored: $O(n^2)$ | No attention matrix: $O(n)$ |
| 128K context → 30 GB just for attention | 128K context → ~0 extra for attention |
| Maximum practical context: ~4K | Maximum practical context: 1M+ |
The jump from 4K to 128K+ context windows in 2023–2024 was largely enabled by Flash Attention.
When Flash Attention Doesn’t Help
Flash Attention optimizes memory IO, not compute. If your bottleneck is raw computation (ample memory bandwidth but not enough FLOPS), Flash Attention provides less benefit.
Specifically:
- Short sequences: Overhead of tiling exceeds IO savings
- Already compute-bound workloads: IO isn’t the bottleneck
- Training with very long sequences: Compute ($O(n^2 d)$) eventually dominates IO regardless
For most production inference on longer sequences, Flash Attention is essential and should always be used.