The O(n²) Problem: Why Doubling Your Context Window Quadruples the Cost

Self-attention computes all pairwise interactions between tokens. For n tokens, that's n² computations. Here's the full mathematical derivation.

The Handshake Analogy

Imagine 5 people walk into a room and everyone shakes hands with everyone else. How many handshakes?

\binom{5}{2} = \frac{5 \times 4}{2} = 10 \text{ handshakes}

Now 10 people:

\binom{10}{2} = \frac{10 \times 9}{2} = 45 \text{ handshakes}

Double the people, but 4.5× the handshakes. With 100 people: 4,950 handshakes. With 1,000 people: 499,500.

This is quadratic growth — and it’s exactly what happens in transformer attention. Every token must “interact” with every other token. Double the tokens, quadruple the computation.
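
The growth is easy to verify with Python's built-in `math.comb`; a quick sketch:

```python
import math

def handshakes(people: int) -> int:
    """Unique pairs among `people` individuals: C(people, 2)."""
    return math.comb(people, 2)

for people in [5, 10, 100, 1_000]:
    print(f"{people:>5} people -> {handshakes(people):>7,} handshakes")
# 5 -> 10, 10 -> 45, 100 -> 4,950, 1,000 -> 499,500
```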

The Self-Attention Computation

Recall the attention formula:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V

Where:

  - Q is the query matrix (shape n × d_k)
  - K is the key matrix (shape n × d_k)
  - V is the value matrix (shape n × d_v)
  - n is the sequence length

Let's count the operations.

Step 1: Compute QK^T

Multiply Q (shape n × d_k) by K^T (shape d_k × n):

S = QK^T \in \mathbb{R}^{n \times n}

FLOPS: Each entry S_ij requires d_k multiplications and d_k − 1 additions, ≈ 2d_k FLOPS.

Total entries: n².

\text{FLOPS}_{QK^T} = n^2 \times 2d_k = O(n^2 d_k)
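
The quadratic term dominates: doubling n quadruples this matmul's cost regardless of d_k. A tiny illustrative check (d_k = 128 is just an example head dimension, not a value from the text):

```python
def qkT_flops(n: int, d_k: int = 128) -> int:
    """FLOPS for Q @ K^T: n^2 entries, each costing ~2*d_k operations."""
    return n * n * 2 * d_k

# Doubling the sequence length quadruples the cost:
print(qkT_flops(2_048) // qkT_flops(1_024))  # 4
```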

Step 2: Scale and Softmax

Division by √d_k: O(n²) operations.

Softmax per row: O(n) operations × n rows = O(n²).

\text{FLOPS}_{\text{softmax}} = O(n^2)
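
Row-wise softmax is cheap relative to the two matmuls. A numerically stable NumPy version (a sketch; the max-subtraction trick prevents overflow in `exp` and is standard practice):

```python
import numpy as np

def row_softmax(S: np.ndarray) -> np.ndarray:
    """Softmax over each row of an (n, n) score matrix, O(n) per row."""
    S = S - S.max(axis=-1, keepdims=True)  # stability: exp() can't overflow
    E = np.exp(S)
    return E / E.sum(axis=-1, keepdims=True)

A = row_softmax(np.random.randn(4, 4))
print(A.sum(axis=-1))  # each row sums to 1.0
```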

Step 3: Multiply by V

The attention weight matrix A (shape n × n) is multiplied by V (shape n × d_v):

\text{Output} = A \times V \in \mathbb{R}^{n \times d_v}

FLOPS: n × d_v output entries, each requiring n multiplications and n additions (≈ 2n FLOPS):

\text{FLOPS}_{AV} = n \times d_v \times 2n = O(n^2 d_v)

Total Per Head

\text{FLOPS}_{\text{head}} = O(n^2 d_k) + O(n^2) + O(n^2 d_v) = O(n^2 d_k)

(Since typically d_k = d_v, and d_k ≫ 1.)
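
Putting the three steps together: a minimal single-head attention in NumPy (an illustrative sketch, not an optimized kernel) that makes the n × n intermediate explicit:

```python
import numpy as np

def single_head_attention(Q, K, V):
    """Scaled dot-product attention. Q, K: (n, d_k); V: (n, d_v).

    Materializes the full (n, n) score matrix -- the source of
    both the O(n^2 d_k) compute and the O(n^2) memory cost.
    """
    d_k = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d_k)           # Step 1: scores, O(n^2 d_k)
    S -= S.max(axis=-1, keepdims=True)   # numerical stability
    A = np.exp(S)
    A /= A.sum(axis=-1, keepdims=True)   # Step 2: row softmax, O(n^2)
    return A @ V                         # Step 3: weighted values, O(n^2 d_v)

n, d_k, d_v = 16, 8, 8
out = single_head_attention(np.random.randn(n, d_k),
                            np.random.randn(n, d_k),
                            np.random.randn(n, d_v))
print(out.shape)  # (16, 8)
```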

Total Across All Heads and Layers

With h heads per layer and L layers, where d_k = d/h (model dimension divided by the number of heads):

\text{FLOPS}_{\text{total}} = L \times h \times O(n^2 \times d/h) = O(L \times n^2 \times d)

The Quadratic Scaling in Practice

Let’s compute actual FLOPS for different context lengths:

```python
def attention_flops(
    n_tokens: int,
    d_model: int = 8192,
    n_heads: int = 64,
    n_layers: int = 80,
) -> dict:
    """
    Calculate attention FLOPS for a transformer model.

    Uses the formula: FLOPS = 4 * L * n^2 * d
    (a factor of 2 for multiply-add, times 2 for the QK^T and AV
    matmuls; the O(n^2) softmax term is negligible and omitted)
    """
    d_k = d_model // n_heads

    # Per head: 2 * n^2 * d_k for QK^T + 2 * n^2 * d_k for AV
    flops_per_head = 4 * n_tokens**2 * d_k

    # All heads in one layer
    flops_per_layer = flops_per_head * n_heads  # = 4 * n^2 * d_model

    # All layers
    total_flops = flops_per_layer * n_layers

    return {
        "flops": total_flops,
        "tflops": total_flops / 1e12,
        "time_h100_sec": total_flops / 990e12,  # H100 ≈ 990 TFLOPS dense FP16
    }


# Compare different context lengths
contexts = [1_000, 4_096, 32_768, 128_000, 200_000, 1_000_000]

print(f"{'Context':>10} {'TFLOPS':>12} {'H100 Time':>12} {'Relative':>10}")
print("=" * 50)

base_flops = None
for n in contexts:
    result = attention_flops(n)
    if base_flops is None:
        base_flops = result["flops"]
    relative = result["flops"] / base_flops

    print(f"{n:>10,} {result['tflops']:>12,.1f} {result['time_h100_sec']:>11.3f}s {relative:>10,.0f}×")
```

Output:

```
   Context       TFLOPS    H100 Time   Relative
==================================================
     1,000          2.6       0.003s          1×
     4,096         44.0       0.044s         17×
    32,768      2,814.7       2.843s      1,074×
   128,000     42,949.7      43.384s     16,384×
   200,000    104,857.6     105.917s     40,000×
 1,000,000  2,621,440.0    2647.919s  1,000,000×
```

Going from 1K to 1M tokens: 1,000,000× more computation for attention alone!

The Memory Wall

It’s not just compute. The n × n attention matrix also has to be stored in memory:

\text{Attention matrix memory} = n^2 \times p \text{ bytes}

Where p is the precision in bytes (2 for FP16, 4 for FP32):

| Context (n) | Matrix Size (entries) | Memory (FP16) |
|------------:|----------------------:|--------------:|
| 1,000 | 10⁶ | 2 MB |
| 4,096 | 1.68 × 10⁷ | 32 MB |
| 32,768 | 1.07 × 10⁹ | 2 GB |
| 128,000 | 1.64 × 10¹⁰ | 30.5 GB |
| 1,000,000 | 10¹² | 1,862 GB |

At 1M tokens, the attention matrix alone needs 1.86 TB per head per layer. This is why Flash Attention (which never materializes the full matrix) is absolutely essential for long context.
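
The table values follow directly from n² × p. A minimal helper reproduces them (assuming 2 bytes per element for FP16):

```python
def attention_matrix_bytes(n: int, bytes_per_elem: int = 2) -> int:
    """Memory for one n x n attention matrix (per head, per layer)."""
    return n * n * bytes_per_elem

for n in [32_768, 128_000, 1_000_000]:
    print(f"{n:>9,}: {attention_matrix_bytes(n) / 2**30:,.2f} GiB")
# 32,768 -> 2.00 GiB; 128,000 -> 30.52 GiB; 1,000,000 -> 1,862.65 GiB
```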

Comparing O(n²) to O(n)

Recurrent architectures (like RNNs, LSTMs, and State Space Models) process tokens sequentially with O(n)O(n) complexity:

```python
import time
import numpy as np

def benchmark_quadratic_vs_linear(sizes):
    """Compare O(n²) attention vs O(n) recurrent computation."""

    print(f"{'n':>10} {'O(n²) time':>12} {'O(n) time':>12} {'Ratio':>10}")
    print("=" * 50)

    for n in sizes:
        # Simulate O(n²): matrix multiply n×d by d×n
        d = 64
        Q = np.random.randn(min(n, 5000), d).astype(np.float32)
        K = np.random.randn(min(n, 5000), d).astype(np.float32)

        start = time.perf_counter()
        _ = Q @ K.T  # O(n²d)
        t_quad = time.perf_counter() - start
        # Scale to actual n
        t_quad *= (n / min(n, 5000)) ** 2

        # Simulate O(n): sequential processing
        start = time.perf_counter()
        state = np.zeros(d, dtype=np.float32)
        for i in range(min(n, 5000)):
            state = 0.9 * state + Q[i]  # O(d) per step
        t_linear = time.perf_counter() - start
        t_linear *= n / min(n, 5000)

        ratio = t_quad / t_linear if t_linear > 0 else float('inf')
        print(f"{n:>10,} {t_quad:>12.6f}s {t_linear:>12.6f}s {ratio:>10.1f}×")

benchmark_quadratic_vs_linear([100, 1_000, 10_000, 100_000])
```

The O(n²) approach is faster for short sequences (due to parallelization) but becomes prohibitively expensive for long sequences.

Why We Still Use O(n²)

Given the quadratic cost, why not switch to O(n) architectures entirely?

Because attention is really good. The quadratic cost buys you:

  1. Global context: Every token can directly attend to every other token
  2. Parallelism: Unlike recurrence, attention can be computed in parallel
  3. Quality: Attention consistently outperforms linear alternatives on most benchmarks

The field’s approach is not to eliminate the quadratic cost but to manage it: with IO-aware kernels like Flash Attention that never materialize the full attention matrix, and with smarter context management that keeps n small in the first place.

The Engineering Decision

When designing an AI system, you face a fundamental tradeoff:

\text{Quality} \times \text{Context Length} \propto \frac{\text{Budget}}{\text{Cost per query}}

\text{Cost per query} \propto n^2

\therefore\ \text{Quality} \times \text{Context Length} \propto \frac{\text{Budget}}{n^2}

Doubling context length quadruples cost but typically provides less than 2× improvement in task performance (due to attention dilution). The cost-effectiveness of longer context decreases rapidly:

\text{Cost-effectiveness} = \frac{\Delta \text{Quality}}{n^2} \to 0 \text{ as } n \to \infty

This is why smart context management — selecting the right tokens to include rather than including everything — is fundamentally more cost-effective than simply expanding the context window.


ByteBell helps engineering teams solve exactly this problem. Instead of stuffing everything into the context window, ByteBell’s Smart Context Refresh retrieves only what matters — keeping your AI sharp, fast, and accurate. Learn more at bytebell.ai