Flash Attention: How One Paper Made Long Context Possible

Flash Attention never materializes the full n×n attention matrix. Instead, it computes in tiles using fast GPU SRAM. Here's how it works and why it's 2-4× faster.

The Kitchen Analogy

Imagine a chef preparing a banquet for 1,000 guests. The standard approach: bring every ingredient from the warehouse to the kitchen counter, prepare all dishes, then serve. The counter (GPU SRAM) is tiny — most ingredients sit in the warehouse (GPU HBM), and the chef spends most of their time running back and forth.

Flash Attention’s approach: bring ingredients in batches. Prepare one table’s dishes at a time. Return ingredients, get the next batch. The chef never needs a massive counter — they just work in tiles.

The key insight of Flash Attention isn’t doing less computation. It’s doing less memory movement. The compute is identical. The IO is dramatically reduced.

GPU Memory Hierarchy

Modern GPUs have two main memory tiers:

| Memory | Size | Bandwidth | Latency |
|---|---|---|---|
| SRAM (on-chip) | 20–40 MB | ~19 TB/s | ~ns |
| HBM (off-chip) | 80–141 GB | ~3 TB/s | ~100 ns |

SRAM is 6× faster but 4,000× smaller. Most GPU operations are bottlenecked by how fast they can move data between HBM and SRAM — this is called being memory-bound or IO-bound.
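A quick way to see why attention is memory-bound is to compare a kernel's arithmetic intensity (FLOPs per byte of HBM traffic) with the GPU's compute-to-bandwidth ratio. A minimal sketch, using the ~3 TB/s bandwidth from the table above and an assumed ~300 TFLOP/s of FP16 compute (the exact figures vary by GPU):

```python
def arithmetic_intensity(flops, bytes_moved):
    """FLOPs performed per byte transferred between HBM and compute units."""
    return flops / bytes_moved

# Balance point: peak FLOP/s divided by HBM bandwidth.
# Assumed numbers: ~300 TFLOP/s FP16 compute, ~3 TB/s HBM bandwidth.
PEAK_FLOPS = 300e12
HBM_BANDWIDTH = 3e12
balance = PEAK_FLOPS / HBM_BANDWIDTH   # ~100 FLOPs needed to amortize each byte

# Elementwise softmax: roughly 5 FLOPs per element, but every FP16 element
# is read once and written once (~4 bytes of HBM traffic per element).
ai_softmax = arithmetic_intensity(5, 4)  # 1.25 FLOPs/byte

print(ai_softmax < balance)  # True — softmax is deeply memory-bound
```

Any kernel whose intensity falls far below the balance point spends most of its time waiting on memory, which is exactly the regime standard attention lives in.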

Standard Attention: The Memory Problem

Standard attention computes:

$$A = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d}}\right) V$$

The naive implementation:

import torch
from torch.nn.functional import softmax

def standard_attention(Q, K, V):
    """
    Standard attention — materializes the full n×n matrix.

    Memory: O(n²) for the attention matrix
    HBM reads/writes: O(nd + n²) — dominated by the n² term
    """
    d = Q.shape[-1]

    # Step 1: Compute S = QK^T — writes an n×n matrix to HBM
    S = Q @ K.T / (d ** 0.5)        # Read Q, K from HBM → O(nd)
                                    # Write S to HBM → O(n²)

    # Step 2: Softmax — reads S from HBM, writes A back to HBM
    A = softmax(S, dim=-1)          # Read S → O(n²), Write A → O(n²)

    # Step 3: Multiply by V — reads A from HBM
    output = A @ V                  # Read A → O(n²), Read V → O(nd)
                                    # Write output → O(nd)

    # Total HBM reads/writes: O(n²) — quadratic in sequence length
    return output

For $n = 128{,}000$ and FP16 (2 bytes per element):

$$\text{Attention matrix size} = n^2 \times 2 \text{ bytes} = 128{,}000^2 \times 2 \approx 30.5 \text{ GB}$$

This single matrix exceeds SRAM capacity by roughly 1,000×, forcing constant HBM access.
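The size claim is easy to verify (GB here means binary GiB, i.e. $2^{30}$ bytes):

```python
def attention_matrix_gib(n, bytes_per_element=2):
    """Size of the full n×n attention matrix in GiB (FP16 by default)."""
    return n * n * bytes_per_element / 2**30

print(round(attention_matrix_gib(128_000), 1))  # 30.5
```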

Flash Attention: The IO-Aware Solution

Flash Attention (Dao et al., 2022) never creates the full $n \times n$ matrix. Instead, it computes attention in tiles that fit in SRAM:

The Tiling Strategy

Divide Q into blocks of size $B_r$ and K, V into blocks of size $B_c$:

$$Q = [Q_1, Q_2, \ldots, Q_{T_r}] \quad \text{where } T_r = \lceil n/B_r \rceil$$

$$K = [K_1, K_2, \ldots, K_{T_c}] \quad \text{where } T_c = \lceil n/B_c \rceil$$

For each pair $(Q_i, K_j)$, compute a small attention tile $S_{ij} \in \mathbb{R}^{B_r \times B_c}$ entirely in SRAM.
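As a quick sanity check that a tile really fits on-chip, here is the block count and per-tile working set for one choice of parameters (the tile sizes $B_r = B_c = 128$ and FP16 storage are assumptions for illustration; real kernels pick block sizes from the SRAM budget):

```python
import math

n, d = 128_000, 128
Br, Bc = 128, 128      # assumed tile sizes
BYTES = 2              # FP16

Tr = math.ceil(n / Br)  # number of Q blocks
Tc = math.ceil(n / Bc)  # number of K/V blocks

# One tile's working set: Q_i, K_j, V_j, plus the Br×Bc score tile S_ij
tile_bytes = (Br * d + 2 * Bc * d + Br * Bc) * BYTES

print(Tr, Tc)                       # 1000 1000
print(tile_bytes // 1024, "KiB")    # 128 KiB — easily fits in 20+ MB of SRAM
```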

The Online Softmax Trick

The challenge: softmax needs the global maximum across all keys, but we’re processing one tile at a time. How do we normalize correctly?

The solution is online softmax — computing softmax incrementally:

$$m_i^{(j)} = \max\left(m_i^{(j-1)}, \max(S_{ij})\right)$$

$$\ell_i^{(j)} = e^{m_i^{(j-1)} - m_i^{(j)}} \cdot \ell_i^{(j-1)} + \text{rowsum}\left(e^{S_{ij} - m_i^{(j)}}\right)$$

$$O_i^{(j)} = \frac{e^{m_i^{(j-1)} - m_i^{(j)}} \cdot \ell_i^{(j-1)} \cdot O_i^{(j-1)} + e^{S_{ij} - m_i^{(j)}} \cdot V_j}{\ell_i^{(j)}}$$

This maintains running statistics ($m$ for the max, $\ell$ for the sum) that allow exact softmax computation without seeing all keys at once.
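A minimal NumPy sketch of these three updates for a single query row, verifying that the streamed result matches the ordinary softmax-weighted sum:

```python
import numpy as np

rng = np.random.default_rng(0)
n_keys, d, block = 64, 8, 16
s = rng.standard_normal(n_keys)        # scores q·k_j for one query row
V = rng.standard_normal((n_keys, d))

# Reference: full softmax over all keys, then weighted sum of V
p = np.exp(s - s.max())
ref = (p / p.sum()) @ V

# Online version: stream over key blocks, maintaining (m, l, O)
m, l, O = -np.inf, 0.0, np.zeros(d)
for j in range(0, n_keys, block):
    s_j, V_j = s[j:j+block], V[j:j+block]
    m_new = max(m, s_j.max())
    corr = np.exp(m - m_new)           # rescale factor for old statistics
    p_j = np.exp(s_j - m_new)
    l_new = corr * l + p_j.sum()
    O = (corr * l * O + p_j @ V_j) / l_new
    m, l = m_new, l_new

print(np.allclose(O, ref))  # True — exact, not approximate
```

The key point is that each block's contribution is rescaled by `corr` whenever a larger maximum appears, so the final result is bit-for-bit the same softmax, just computed incrementally.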

Reference Implementation

import torch

def flash_attention(Q, K, V, block_size_r=128, block_size_c=128):
    """
    Flash Attention: IO-aware exact attention.

    Key: never materializes the n×n attention matrix.
    Computes in SRAM-sized tiles with online softmax.

    Memory: O(n) extra — no n×n matrix
    HBM IO: O(n²d²/M) where M = SRAM size
    Compute: O(n²d) — same as standard attention
    """
    n, d = Q.shape
    O = torch.zeros(n, d)                  # Output accumulator
    m = torch.full((n,), float('-inf'))    # Running max (for stable softmax)
    l = torch.zeros(n)                     # Running sum (softmax denominator)

    # Tile over K, V blocks (outer loop)
    for j in range(0, n, block_size_c):
        Kj = K[j:j+block_size_c]           # Load tile from HBM into SRAM
        Vj = V[j:j+block_size_c]

        # Tile over Q blocks (inner loop)
        for i in range(0, n, block_size_r):
            Qi = Q[i:i+block_size_r]       # Load tile from HBM into SRAM

            # --- Everything below happens in SRAM ---

            # Compute local attention scores
            Sij = Qi @ Kj.T / d ** 0.5     # Shape: (Br × Bc)

            # Online softmax update: new running max per query row
            m_new = torch.maximum(m[i:i+block_size_r], Sij.max(dim=-1).values)

            # Correction factor to rescale previously accumulated values
            correction = torch.exp(m[i:i+block_size_r] - m_new)

            # Unnormalized probabilities for this tile
            P = torch.exp(Sij - m_new.unsqueeze(-1))

            # Update running sum
            l_new = correction * l[i:i+block_size_r] + P.sum(dim=-1)

            # Update output: rescale old output, add this tile's contribution
            O[i:i+block_size_r] = (
                (correction * l[i:i+block_size_r]).unsqueeze(-1)
                * O[i:i+block_size_r]
                + P @ Vj
            ) / l_new.unsqueeze(-1)

            # Update statistics
            m[i:i+block_size_r] = m_new
            l[i:i+block_size_r] = l_new

            # --- End SRAM computation ---

    return O

The IO Analysis

Standard Attention

$$\text{HBM accesses} = O(nd + n^2) \quad \text{(reads + writes)}$$

The $n^2$ term dominates: reading and writing the full attention matrix.

Flash Attention

$$\text{HBM accesses} = O\!\left(\frac{n^2 d^2}{M}\right)$$

where $M$ is the SRAM size. Since $M \gg d^2$ (SRAM is ~20 MB, while $d^2 = 128^2 = 16{,}384$ elements is only ~32 KB in FP16):

$$\frac{n^2 d^2}{M} \ll n^2$$

For typical values ($M = 20 \times 10^6$, $d = 128$, $n = 128\text{K}$):

$$\text{Standard IO} \propto n^2 \approx 1.6 \times 10^{10}$$

$$\text{Flash IO} \propto \frac{n^2 d^2}{M} = \frac{1.6 \times 10^{10} \times 16{,}384}{20 \times 10^6} \approx 1.3 \times 10^{7}$$

That is roughly a 1,000× reduction in HBM traffic, and the reduction grows with $M/d^2$. In practice, Flash Attention achieves a 2–4× wall-clock speedup over standard attention; the speedup is smaller than the IO reduction because compute and other overheads remain.
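These IO estimates can be reproduced with a few lines (using $M = 20 \times 10^6$, treating the SRAM size in elements loosely, as the estimate in the text does):

```python
n, d = 128_000, 128
M = 20e6                       # SRAM size used in the text's estimate

standard_io = n * n            # dominated by the n×n attention matrix
flash_io = n * n * d * d / M   # Flash Attention's IO bound

print(f"{standard_io:.1e}")    # 1.6e+10
print(f"{flash_io:.1e}")       # 1.3e+07
print(f"~{standard_io / flash_io:.0f}x less HBM traffic")
```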

Flash Attention 2 and 3

Flash Attention 2 (2023)

Key improvements:

- Parallelizes across the sequence-length dimension, not just batch and heads, which keeps the GPU busy for long sequences at small batch sizes
- Reduces non-matmul FLOPs by deferring the output rescaling to the end of the inner loop
- Better work partitioning between warps within a thread block, cutting shared-memory reads and writes

The result is roughly a 2× speedup over the original Flash Attention.

Flash Attention 3 (2024)

Targets Hopper GPUs (H100):

- Exploits hardware asynchrony: overlaps data movement (TMA) with Tensor Core matmuls using warp specialization
- Interleaves block-wise matmul and softmax to hide the latency of the softmax operations
- Adds FP8 support, with techniques such as block quantization to limit the accuracy loss

Impact on Context Length

Flash Attention’s memory savings directly enabled longer context:

| Without Flash Attention | With Flash Attention |
|---|---|
| Attention matrix stored: $O(n^2)$ | No attention matrix: $O(n)$ extra |
| 128K context → 30 GB just for attention | 128K context → negligible extra memory |
| Maximum practical context: ~4K | Maximum practical context: 1M+ |

The jump from 4K to 128K+ context windows in 2023–2024 was largely enabled by Flash Attention.

When Flash Attention Doesn’t Help

Flash Attention optimizes memory IO, not compute. If your bottleneck is raw computation (enough GPU memory but not enough FLOPS), Flash Attention provides less benefit.

Specifically:

- Very short sequences (roughly $n < 1{,}000$): the attention matrix is small enough that tiling overhead and extra rescaling work can offset the IO savings
- Compute-bound regimes: if the matmul FLOPs already saturate the GPU, reducing memory traffic moves the bottleneck very little
- Unsupported attention variants: score patterns a fused kernel does not implement may force a fallback to the standard path

For most production inference with n>4,096n > 4{,}096, Flash Attention is essential and should always be used.
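In PyTorch you rarely write the kernel yourself: `torch.nn.functional.scaled_dot_product_attention` dispatches to a Flash Attention backend when the hardware, dtypes, and shapes allow it. A minimal sketch (on CPU, as here, the fused kernel may not actually be selected, but the result is the same exact attention):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 8, 256, 64)   # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 256, 64)
v = torch.randn(1, 8, 256, 64)

# Fused attention — never materializes the seq×seq matrix on supported backends
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Reference implementation with an explicit causal mask
mask = torch.tril(torch.ones(256, 256, dtype=torch.bool))
scores = (q @ k.transpose(-2, -1)) / 64 ** 0.5
scores = scores.masked_fill(~mask, float('-inf'))
ref = scores.softmax(dim=-1) @ v

print(torch.allclose(out, ref, atol=1e-4))  # True
```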


ByteBell helps engineering teams solve exactly this problem. Instead of stuffing everything into the context window, ByteBell’s Smart Context Refresh retrieves only what matters — keeping your AI sharp, fast, and accurate. Learn more at bytebell.ai