Flash Attention never materializes the full n×n attention matrix. Instead, it computes in tiles using fast GPU SRAM. Here's how it works and why it's 2-4× faster.
Imagine a chef preparing a banquet for 1,000 guests. The standard approach: bring every ingredient from the warehouse to the kitchen counter, prepare all dishes, then serve. The counter (GPU SRAM) is tiny — most ingredients sit in the warehouse (GPU HBM), and the chef spends most of their time running back and forth.
Flash Attention’s approach: bring ingredients in batches. Prepare one table’s dishes at a time. Return ingredients, get the next batch. The chef never needs a massive counter — they just work in tiles.
The key insight of Flash Attention isn’t doing less computation. It’s doing less memory movement. The compute is identical. The IO is dramatically reduced.
Modern GPUs have two main memory tiers:
| Memory | Size | Bandwidth | Latency |
|---|---|---|---|
| SRAM (on-chip) | 20–40 MB | ~19 TB/s | ~ns |
| HBM (off-chip) | 80–141 GB | ~3 TB/s | ~100ns |
SRAM is 6× faster but 4,000× smaller. Most GPU operations are bottlenecked by how fast they can move data between HBM and SRAM — this is called being memory-bound or IO-bound.
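To see why the bandwidth gap matters, here is a back-of-envelope sketch using the approximate figures from the table above: how long a single pass over a 128K × 128K FP16 attention matrix would take at HBM versus SRAM bandwidth.

```python
# Rough sketch using the table's approximate bandwidth figures.
n = 128 * 1024            # 128K-token context
matrix_bytes = n * n * 2  # n×n attention matrix in FP16
hbm_bw = 3e12             # ~3 TB/s (HBM)
sram_bw = 19e12           # ~19 TB/s (SRAM)

print(f"matrix size: {matrix_bytes / 1e9:.1f} GB")
print(f"one pass at HBM speed:  {matrix_bytes / hbm_bw * 1e3:.2f} ms")
print(f"one pass at SRAM speed: {matrix_bytes / sram_bw * 1e3:.2f} ms")
```

Standard attention reads and writes this matrix multiple times per layer, and it is far too large to ever live in SRAM, so every pass pays the HBM price.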
Standard attention computes softmax(QKᵀ / √d) · V, where Q, K, V are n × d matrices. The naive implementation:
```python
import torch

def standard_attention(Q, K, V):
    """
    Standard attention — materializes the full n×n matrix.
    Memory: O(n²) for the attention matrix
    HBM reads/writes: O(nd + n²) — dominated by n²
    """
    d = Q.shape[-1]

    # Step 1: Compute S = QK^T — writes an n×n matrix to HBM
    S = Q @ K.T / (d ** 0.5)      # Read Q, K from HBM → O(nd)
                                  # Write S to HBM    → O(n²)

    # Step 2: Softmax — reads S from HBM, writes A to HBM
    A = torch.softmax(S, dim=-1)  # Read S → O(n²), Write A → O(n²)

    # Step 3: Multiply by V — reads A from HBM
    output = A @ V                # Read A → O(n²), Read V → O(nd)
                                  # Write output → O(nd)
    return output

# Total HBM reads/writes: O(n²) — QUADRATIC in sequence length
```

For n = 128K at FP16, the attention matrix alone is (128 × 1024)² × 2 bytes ≈ 34 GB. This single matrix exceeds SRAM capacity by roughly 1,000×, forcing constant HBM access.
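As a sanity check, the three explicit steps above match PyTorch's built-in `scaled_dot_product_attention` (which dispatches to a fused, Flash-style kernel where available):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d = 256, 64
Q, K, V = (torch.randn(n, d) for _ in range(3))

# The three explicit steps of standard attention
S = Q @ K.T / d ** 0.5
out = torch.softmax(S, dim=-1) @ V

# PyTorch's fused reference implementation
ref = F.scaled_dot_product_attention(Q, K, V)
assert torch.allclose(out, ref, atol=1e-4)
```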
Flash Attention (Dao, 2022) never creates the full matrix. Instead, it computes attention in tiles that fit in SRAM:
Divide Q into blocks of size B_r × d, and K and V into blocks of size B_c × d. For each block pair (i, j), compute a small B_r × B_c attention tile entirely in SRAM.
The challenge: softmax needs the global maximum across all keys, but we’re processing one tile at a time. How do we normalize correctly?
The solution is online softmax — computing softmax incrementally. When a new block of scores arrives, update the running maximum and rescale the running sum: m_new = max(m_old, max(S_block)) and ℓ_new = e^(m_old − m_new) · ℓ_old + Σ e^(s − m_new), summed over the new scores s. This maintains running statistics (m for the max, ℓ for the sum) that allow exact softmax computation without seeing all keys at once.
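A minimal numeric sketch of the idea: process the scores in two chunks, maintaining only m (running max) and ℓ (running sum), and recover exactly the same softmax denominator as a single full pass would.

```python
import math

scores = [2.0, 5.0, 1.0, 3.0, 4.0, 0.5]

m, l = float("-inf"), 0.0
for chunk in (scores[:3], scores[3:]):  # two "tiles" of keys
    m_new = max(m, max(chunk))
    # Rescale the old sum to the new max, then add the new exponentials
    l = math.exp(m - m_new) * l + sum(math.exp(s - m_new) for s in chunk)
    m = m_new

# Reference: ordinary one-pass softmax denominator
m_full = max(scores)
l_full = sum(math.exp(s - m_full) for s in scores)
assert abs(l - l_full) < 1e-12
```

The same update, applied row-wise to score tiles, is the core of the algorithm below.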
```python
import torch

def flash_attention(Q, K, V, block_size_r=128, block_size_c=128):
    """
    Flash Attention: IO-aware exact attention.
    Key: never materializes the n×n attention matrix.
    Computes in SRAM-sized tiles with online softmax.
    Memory: O(n) extra — no n×n matrix
    HBM IO: O(n²d²/M) where M = SRAM size
    Compute: O(n²d) — same as standard attention
    """
    n, d = Q.shape
    O = torch.zeros(n, d)                # Output accumulator
    m = torch.full((n,), float("-inf"))  # Running max (for stable softmax)
    l = torch.zeros(n)                   # Running sum (softmax denominator)

    # Tile over K, V blocks (outer loop)
    for j in range(0, n, block_size_c):
        Kj = K[j:j+block_size_c]         # Load tile from HBM to SRAM
        Vj = V[j:j+block_size_c]

        # Tile over Q blocks (inner loop)
        for i in range(0, n, block_size_r):
            Qi = Q[i:i+block_size_r]     # Load tile from HBM to SRAM

            # --- Everything below happens in SRAM ---
            # Compute local attention scores
            Sij = Qi @ Kj.T / d ** 0.5   # Shape: (Br, Bc)

            # Online softmax update
            m_new = torch.maximum(m[i:i+block_size_r], Sij.max(dim=-1).values)
            # Correction factor for the previous accumulation
            correction = torch.exp(m[i:i+block_size_r] - m_new)
            # New exponentials
            P = torch.exp(Sij - m_new.unsqueeze(-1))
            # Update running sum
            l_new = correction * l[i:i+block_size_r] + P.sum(dim=-1)

            # Update output: rescale old output + add new contribution
            O[i:i+block_size_r] = (
                (correction * l[i:i+block_size_r]).unsqueeze(-1)
                * O[i:i+block_size_r]
                + P @ Vj
            ) / l_new.unsqueeze(-1)

            # Update statistics
            m[i:i+block_size_r] = m_new
            l[i:i+block_size_r] = l_new
            # --- End SRAM computation ---

    return O
```

Standard attention's HBM IO is O(nd + n²); the n² term dominates — reading and writing the full attention matrix.
Flash Attention's HBM IO is O(n²d²/M), where M is the SRAM size in elements. Since d² ≪ M (SRAM is ~20 MB, while d² is typically 128² = 16,384 elements ≈ 32 KB in FP16), this is far fewer HBM accesses than the standard O(n²).
For typical values (n = 128K, d = 128, M ≈ 20 MB), the attention IO shrinks by a factor of roughly M/d², several hundred×. The reduction grows with M: larger SRAM means larger tiles and fewer passes over HBM. In practice, Flash Attention achieves a 2–4× wall-clock speedup over standard attention.
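Plugging in the numbers (a sketch, taking M as ~20 MB of SRAM counted in FP16 elements) gives the scale of the IO reduction:

```python
n = 128 * 1024              # sequence length
d = 128                     # head dimension
M = 20 * 1024 * 1024 // 2   # ~20 MB of SRAM, in FP16 elements

standard_io = n * d + n * n        # O(nd + n²)
flash_io = n * n * d * d // M      # O(n²d²/M)

print(f"standard: {standard_io:.3e} element accesses")
print(f"flash:    {flash_io:.3e} element accesses")
print(f"reduction: ~{standard_io / flash_io:.0f}x")
```

The asymptotic IO reduction is much larger than the observed 2–4× wall-clock speedup because attention IO is only one part of each layer's runtime, and the compute still has to happen.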
FlashAttention-2 (2023) brought key improvements: parallelizing across the sequence-length dimension in addition to batch and heads, reducing non-matmul FLOPs, and partitioning work between GPU warps more evenly, which roughly doubled throughput again.
FlashAttention-3 (2024) targets Hopper GPUs (H100): it exploits the asynchrony between the Tensor Cores and the TMA copy engine to overlap computation with data movement, and adds low-precision FP8 support.
Flash Attention’s memory savings directly enabled longer context:
| Without Flash Attention | With Flash Attention |
|---|---|
| Attention matrix stored: O(n²) memory | No attention matrix: O(n) extra memory |
| 128K context → ~34 GB just for attention | 128K context → ~0 extra for attention |
| Maximum practical context: ~4K | Maximum practical context: 1M+ |
The jump from 4K to 128K+ context windows in 2023–2024 was largely enabled by Flash Attention.
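The per-context attention-matrix cost in the left column is just n² × 2 bytes in FP16, which a quick loop confirms:

```python
# FP16 attention-matrix size at various context lengths (no Flash Attention)
for ctx in (4_096, 32_768, 131_072):
    gb = ctx * ctx * 2 / 1e9
    print(f"{ctx:>7} tokens -> {gb:7.2f} GB")
```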
Flash Attention optimizes memory IO, not compute. If your bottleneck is raw computation (enough GPU memory but not enough FLOPS), Flash Attention provides less benefit. Specifically, very short sequences produce score tiles that already fit in SRAM, so there is little IO to save, and compute-bound workloads see smaller gains. For most production inference with long contexts, Flash Attention is essential and should always be used.
ByteBell helps engineering teams solve exactly this problem. Instead of stuffing everything into the context window, ByteBell’s Smart Context Refresh retrieves only what matters — keeping your AI sharp, fast, and accurate. Learn more at bytebell.ai