The O(n²) Problem: Why Doubling Context Quadruples Cost
The Handshake Analogy
Imagine 5 people walk into a room and everyone shakes hands with everyone else. How many handshakes?
$$\binom{5}{2} = \frac{5 \times 4}{2} = 10$$
Now 10 people:
$$\binom{10}{2} = \frac{10 \times 9}{2} = 45$$
Double the people, but 4.5× the handshakes. With 100 people: 4,950 handshakes. With 1,000 people: 499,500.
This is quadratic growth — and it’s exactly what happens in transformer attention. Every token must “interact” with every other token. Double the tokens, quadruple the computation.
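The handshake count is the number of unordered pairs, n(n − 1)/2. A minimal sketch that reproduces the numbers above and shows the doubling ratio settling toward 4 as the room grows (the function name `handshakes` is just an illustrative choice):

```python
def handshakes(n: int) -> int:
    """Number of unordered pairs among n people: n * (n - 1) / 2."""
    return n * (n - 1) // 2


for n in (5, 10, 100, 1_000):
    # Ratio of handshakes after doubling the number of people
    ratio = handshakes(2 * n) / handshakes(n)
    print(f"{n:>5} people: {handshakes(n):>7,} handshakes, "
          f"doubling multiplies by {ratio:.2f}")
```

For small groups the ratio is a little above 4 (4.5 when going from 5 to 10 people) because of the −1 term; for large n it is effectively 4.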
The Self-Attention Computation
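Recall the attention formula:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$
Where:
- $Q, K, V \in \mathbb{R}^{n \times d_k}$ ($n$ tokens, each with $d_k$ dimensions)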
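As a concrete reference for the operation count, here is a minimal single-head NumPy sketch of scaled dot-product attention, softmax(QKᵀ/√d_k)·V. The function name and the toy sizes are illustrative assumptions; production implementations add masking, batching, and multiple heads.

```python
import numpy as np


def scaled_dot_product_attention(Q, K, V):
    """Single-head attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                        # (n, n): one score per token pair
    scores = scores - scores.max(axis=-1, keepdims=True)   # stabilise exp()
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # each row sums to 1
    return weights @ V, weights                            # output has shape (n, d_k)


rng = np.random.default_rng(0)
n, d_k = 6, 4                                              # toy sizes, chosen for illustration
Q, K, V = (rng.standard_normal((n, d_k)) for _ in range(3))
out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape, weights.shape)                            # (6, 4) (6, 6)
```

The `(n, n)` shape of `weights` is the source of the quadratic cost: it holds one entry for every pair of tokens.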
Let’s count the operations.
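Step 1: Compute $QK^T$
Multiply $Q$ (shape $n \times d_k$) by $K^T$ (shape $d_k \times n$):
$$QK^T \in \mathbb{R}^{n \times n}$$
FLOPS: Each entry requires $d_k$ multiplications and $d_k - 1$ additions ≈ $2d_k$ FLOPS.
Total entries: $n^2$.
$$\text{FLOPS}_{QK^T} \approx 2 n^2 d_k$$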
Step 2: Scale and Softmax
Division by $\sqrt{d_k}$: $n^2$ operations.
Softmax per row: $O(n)$ operations × $n$ rows = $O(n^2)$.
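Step 3: Multiply by $V$
The attention weight matrix $A$ (shape $n \times n$) multiplied by $V$ (shape $n \times d_k$):
$$AV \in \mathbb{R}^{n \times d_k}$$
FLOPS: $n \times d_k$ output entries, each requiring $n$ multiplications and $n - 1$ additions ≈ $2n$ FLOPS:
$$\text{FLOPS}_{AV} \approx 2 n^2 d_k$$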
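Total Per Head
$$\text{FLOPS}_{\text{head}} \approx 2 n^2 d_k + O(n^2) + 2 n^2 d_k \approx 4 n^2 d_k$$
(The $O(n^2)$ scale-and-softmax term is negligible, since typically $d_k \gg 1$ and therefore $n^2 d_k \gg n^2$.)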
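To see how small the scale-and-softmax term is next to the two matrix multiplies, the per-head count can be tallied step by step. This is a rough sketch: the 2·n²·d_k figures for QKᵀ and AV follow the counts used in this article, while the constant of 4 operations per score entry for scaling plus softmax is an assumption made only for illustration.

```python
def per_head_flops(n: int, d_k: int) -> dict:
    """Approximate FLOP counts for one attention head."""
    qk = 2 * n * n * d_k   # QK^T: n^2 entries, about 2 * d_k FLOPS each
    # Assumption: ~1 op per entry for scaling plus ~3 for softmax (exp, sum, divide)
    softmax = 4 * n * n
    av = 2 * n * n * d_k   # weights @ V: n * d_k entries, about 2 * n FLOPS each
    total = qk + softmax + av
    return {
        "qk": qk,
        "softmax": softmax,
        "av": av,
        "total": total,
        "softmax_share": softmax / total,
    }


for n in (1_000, 2_000):
    counts = per_head_flops(n, d_k=128)
    print(f"n={n:>5,}: total={counts['total']:.3e}, "
          f"softmax share={counts['softmax_share']:.2%}")
```

Under these assumptions the softmax share stays below 1% for d_k = 128 regardless of n, and doubling n multiplies every term by exactly 4.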
Total Across All Heads and Layers
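With $h$ heads per layer and $L$ layers, where $d_k = d_{\text{model}} / h$ (model dimension divided by number of heads):
$$\text{FLOPS}_{\text{total}} = L \cdot h \cdot 4 n^2 d_k = 4 L n^2 d_{\text{model}}$$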
The Quadratic Scaling in Practice
Let’s compute actual FLOPS for different context lengths:
    def attention_flops(
        n_tokens: int,
        d_model: int = 8192,
        n_heads: int = 64,
        n_layers: int = 80,
    ) -> dict:
        """
        Calculate attention FLOPS for a transformer model.

        Uses the formula: FLOPS = 4 * L * n^2 * d_model
        (QK^T and AV each cost 2 * n^2 * d_model per layer, counting a
        multiply-add as 2 FLOPS; the O(n^2) softmax term is ignored).
        """
        d_k = d_model // n_heads
        # Per head: 2 * n^2 * d_k for QK^T + 2 * n^2 * d_k for AV
        flops_per_head = 4 * n_tokens**2 * d_k
        # All heads in one layer
        flops_per_layer = flops_per_head * n_heads  # = 4 * n^2 * d_model
        # All layers
        total_flops = flops_per_layer * n_layers
        return {
            "flops": total_flops,
            "tflops": total_flops / 1e12,
            # Assumes peak H100 throughput (~990 TFLOPS FP16)
            "time_h100_sec": total_flops / (990e12),
        }


    # Compare different context lengths
    contexts = [1_000, 4_096, 32_768, 128_000, 200_000, 1_000_000]
    print(f"{'Context':>10} {'TFLOPS':>12} {'H100 Time':>12} {'Relative':>10}")
    print("=" * 50)
    base_flops = None
    for n in contexts:
        result = attention_flops(n)
        if base_flops is None:
            base_flops = result["flops"]
        relative = result["flops"] / base_flops
        print(f"{n:>10,} {result['tflops']:>12,.1f} {result['time_h100_sec']:>11.3f}s {relative:>10,.0f}×")

Output:
       Context       TFLOPS    H100 Time   Relative
    ==================================================
         1,000          2.6       0.003s          1×
         4,096         44.0       0.044s         17×
        32,768      2,814.7       2.843s      1,074×
       128,000     42,949.7      43.384s     16,384×
       200,000    104,857.6     105.917s     40,000×
     1,000,000  2,621,440.0    2647.919s  1,000,000×

Going from 1K to 1M tokens: 1,000,000× more computation for attention alone!
The Memory Wall
It’s not just compute. The attention matrix also needs to be stored in memory:
$$\text{Memory} = n^2 \times b \text{ bytes}$$
Where $b$ is the precision in bytes per element (2 bytes for FP16, 4 for FP32):
| Context ($n$) | Matrix Size | Memory (FP16) |
|---|---|---|
| 1,000 | 1,000 × 1,000 | 2 MB |
| 4,096 | 4,096 × 4,096 | 32 MB |
| 32,768 | 32,768 × 32,768 | 2 GB |
| 128,000 | 128,000 × 128,000 | 30.5 GB |
| 1,000,000 | 1,000,000 × 1,000,000 | 1,862 GB |
At 1M tokens, the attention matrix alone needs roughly 2 TB (1,862 GiB, about 1.8 TiB) per head per layer. This is why Flash Attention (which never materializes the full matrix) is absolutely essential for long context.
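The table values follow directly from n² × b bytes. A small sketch that recomputes them, assuming binary units (1 GiB = 2³⁰ bytes), which is what the table's figures correspond to; both helper names are illustrative.

```python
def attention_matrix_bytes(n_tokens: int, bytes_per_element: int = 2) -> int:
    """Bytes needed to store one n x n attention matrix (FP16 by default)."""
    return n_tokens * n_tokens * bytes_per_element


def format_binary(num_bytes: float) -> str:
    """Render a byte count in binary units (KiB, MiB, GiB, TiB)."""
    for unit in ("B", "KiB", "MiB", "GiB"):
        if num_bytes < 1024:
            return f"{num_bytes:,.1f} {unit}"
        num_bytes /= 1024
    return f"{num_bytes:,.1f} TiB"


for n in (1_000, 4_096, 32_768, 128_000, 1_000_000):
    size = attention_matrix_bytes(n)
    print(f"{n:>10,} tokens: {format_binary(size):>12} per head per layer")
```

Multiplying by the number of heads and layers gives the total if every matrix were materialized at once, which real implementations avoid.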
Comparing O(n²) to O(n)
Recurrent architectures (like RNNs, LSTMs, and State Space Models) process tokens sequentially, with $O(n)$ complexity in the sequence length:
    import time

    import numpy as np


    def benchmark_quadratic_vs_linear(sizes):
        """Compare O(n²) attention vs O(n) recurrent computation."""
        print(f"{'n':>10} {'O(n²) time':>12} {'O(n) time':>12} {'Ratio':>10}")
        print("=" * 50)
        d = 64
        for n in sizes:
            # Measure at most 5,000 tokens; larger n is extrapolated below
            m = min(n, 5000)
            # Simulate O(n²): matrix multiply (m×d) by (d×m)
            Q = np.random.randn(m, d).astype(np.float32)
            K = np.random.randn(m, d).astype(np.float32)
            start = time.perf_counter()
            _ = Q @ K.T  # O(m²d)
            t_quad = time.perf_counter() - start
            # Scale quadratically to the actual n
            t_quad *= (n / m) ** 2
            # Simulate O(n): sequential processing
            start = time.perf_counter()
            state = np.zeros(d, dtype=np.float32)
            for i in range(m):
                state = 0.9 * state + Q[i]  # O(d) per step
            t_linear = time.perf_counter() - start
            # Scale linearly to the actual n
            t_linear *= n / m
            ratio = t_quad / t_linear if t_linear > 0 else float('inf')
            print(f"{n:>10,} {t_quad:>12.6f}s {t_linear:>12.6f}s {ratio:>10.1f}×")


    benchmark_quadratic_vs_linear([100, 1_000, 10_000, 100_000])

In this benchmark the O(n²) approach is faster for short sequences, because the matrix multiply runs as a single parallelized call while the recurrence steps through tokens one at a time. For long sequences the quadratic term dominates and attention becomes prohibitively expensive.
Why We Still Use O(n²)
Given the quadratic cost, why not switch to O(n) architectures entirely?
Because attention is really good. The quadratic cost buys you:
- Global context: Every token can directly attend to every other token
- Parallelism: Unlike recurrence, attention can be computed in parallel
- Quality: Attention consistently outperforms linear alternatives on most benchmarks
The field’s approach is not to eliminate the quadratic cost but to manage it:
- Flash Attention: Same O(n²) compute, but dramatically less memory IO
- Sparse Attention: Compute attention only between select pairs of tokens
- KV Cache Compression: Reduce the effective n by compressing old key-value pairs
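To illustrate the first idea, here is a simplified NumPy sketch of blockwise attention with an online softmax, the principle that lets Flash Attention avoid storing the full n × n matrix. It is not the actual Flash Attention kernel: it tiles only over keys, runs on the CPU, and the block size of 64 is an arbitrary assumption. The compute is still O(n²), but the largest score buffer is n × block instead of n × n.

```python
import numpy as np


def naive_attention(Q, K, V):
    """Reference implementation that materializes the full (n, n) matrix."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V


def tiled_attention(Q, K, V, block=64):
    """Same result, processing keys/values one block at a time."""
    n, d_k = Q.shape
    out = np.zeros((n, V.shape[1]))
    row_max = np.full((n, 1), -np.inf)   # running max of scores per query
    row_sum = np.zeros((n, 1))           # running softmax denominator
    for start in range(0, K.shape[0], block):
        K_blk = K[start:start + block]
        V_blk = V[start:start + block]
        s = Q @ K_blk.T / np.sqrt(d_k)   # (n, block) scores only
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        # Rescale earlier partial results to the new running max
        correction = np.exp(row_max - new_max)
        p = np.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ V_blk
        row_max = new_max
    return out / row_sum


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((300, 32)) for _ in range(3))
diff = np.abs(naive_attention(Q, K, V) - tiled_attention(Q, K, V)).max()
print(f"max abs difference: {diff:.2e}")
```

The two functions agree up to floating-point rounding, which shows that the memory saving comes from reordering the computation rather than from approximating it.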
The Engineering Decision
When designing an AI system, you face a fundamental tradeoff: doubling context length quadruples attention cost but typically provides less than 2× improvement in task performance (due to attention dilution). The cost-effectiveness of longer context therefore decreases rapidly.
This is why smart context management — selecting the right tokens to include rather than including everything — is fundamentally more cost-effective than simply expanding the context window.
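A back-of-the-envelope sketch of what selection buys, using the 4 · L · n² · d_model attention estimate from this article. The two context sizes are hypothetical, chosen only to illustrate the arithmetic, and the function name is an illustrative choice.

```python
def attention_flops_estimate(n_tokens: int, d_model: int = 8192, n_layers: int = 80) -> int:
    """Attention FLOPS estimate: 4 * L * n^2 * d_model."""
    return 4 * n_layers * n_tokens**2 * d_model


full_context = 200_000      # hypothetical: everything stuffed into the window
selected_context = 8_000    # hypothetical: only the retrieved, relevant tokens

savings = attention_flops_estimate(full_context) / attention_flops_estimate(selected_context)
print(f"{full_context // selected_context}× fewer tokens -> {savings:,.0f}× fewer attention FLOPS")
```

Because the cost is quadratic, a 25× reduction in tokens yields a 625× reduction in attention compute under this estimate.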
ByteBell helps engineering teams solve exactly this problem. Instead of stuffing everything into the context window, ByteBell’s Private Code Context retrieves only what matters — keeping your AI sharp, fast, and accurate. Learn more at bytebell.ai