The exact formula for KV cache memory and worked examples for every major model architecture. Calculate your GPU requirements precisely.
Imagine you’re a librarian processing a line of visitors. For each visitor, you write two index cards — one describing what they’re looking for (Key) and one with their contact info (Value). You need to keep every card from every visitor because any future visitor might need to reference a past one.
Your filing cabinet has limited space. By visitor 1,000, you’re running out of drawers. By visitor 100,000, you need an entire warehouse.
This is the KV cache problem — and this blog gives you the exact formulas to calculate the warehouse size.
| Symbol | Meaning | Unit |
|---|---|---|
| `2` | Keys + Values (two tensors per token per layer) | - |
| `L` | Number of layers | count |
| `H_kv` | Number of KV heads | count |
| `d_head` | Head dimension | elements |
| `S` | Sequence length | tokens |
| `B` | Batch size | count |
| `p` | Bytes per element | bytes |
The KV cache size in bytes is:

```
KV bytes = 2 × L × H_kv × d_head × S × B × p
```

Since `d_kv = H_kv × d_head` (total KV dimension per layer), this is equivalently `2 × L × d_kv × S × B × p`.

Bytes per token (useful for quick calculations):

```
bytes/token = 2 × L × H_kv × d_head × p
```
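As a quick sanity check, here is the per-token formula applied to Llama 3.1 8B (32 layers, 8 KV heads, head dimension 128, FP16); the same numbers appear in the model configs that follow:

```python
# Bytes per token for Llama 3.1 8B at FP16:
# 2 (K+V) × L layers × H_kv KV heads × d_head elements × p bytes/element
L, H_kv, d_head, p = 32, 8, 128, 2
bytes_per_token = 2 * L * H_kv * d_head * p
print(bytes_per_token)                       # 131072 bytes ≈ 128 KiB per token
print(bytes_per_token * 128_000 / 1024**3)   # ≈ 15.6 GB of KV cache at a 128K context
```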
```python
# Complete model architecture specifications
MODEL_CONFIGS = {
    "Llama-3.1-8B": {
        "layers": 32,
        "q_heads": 32,
        "kv_heads": 8,      # GQA: 4 query heads per KV head
        "head_dim": 128,
        "model_dim": 4096,
        "params": "8B",
        "weight_memory_fp16_gb": 16,
    },
    "Llama-3.1-70B": {
        "layers": 80,
        "q_heads": 64,
        "kv_heads": 8,      # GQA: 8 query heads per KV head
        "head_dim": 128,
        "model_dim": 8192,
        "params": "70B",
        "weight_memory_fp16_gb": 140,
    },
    "Llama-3.1-405B": {
        "layers": 126,
        "q_heads": 128,
        "kv_heads": 8,      # GQA: 16 query heads per KV head
        "head_dim": 128,
        "model_dim": 16384,
        "params": "405B",
        "weight_memory_fp16_gb": 810,
    },
    "Mistral-7B": {
        "layers": 32,
        "q_heads": 32,
        "kv_heads": 8,      # GQA: 4 query heads per KV head
        "head_dim": 128,
        "model_dim": 4096,
        "params": "7B",
        "weight_memory_fp16_gb": 14,
    },
    "Mixtral-8x7B": {
        "layers": 32,
        "q_heads": 32,
        "kv_heads": 8,
        "head_dim": 128,
        "model_dim": 4096,
        "params": "47B (active: 13B)",
        "weight_memory_fp16_gb": 94,
    },
    "Qwen-2.5-72B": {
        "layers": 80,
        "q_heads": 64,
        "kv_heads": 8,
        "head_dim": 128,
        "model_dim": 8192,
        "params": "72B",
        "weight_memory_fp16_gb": 144,
    },
}

def kv_cache_memory(config: dict, seq_length: int, batch_size: int = 1,
                    precision: str = "fp16") -> dict:
    """Calculate KV cache memory for a given model configuration."""
    p_bytes = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1, "int4": 0.5}[precision]
    # Factor of 2 = one Key tensor + one Value tensor per token per layer
    memory_bytes = (
        2 * config["layers"] * config["kv_heads"] * config["head_dim"]
        * seq_length * batch_size * p_bytes
    )
    bytes_per_token = 2 * config["layers"] * config["kv_heads"] * config["head_dim"] * p_bytes
    return {
        "total_gb": memory_bytes / (1024**3),
        "bytes_per_token": bytes_per_token,
        "kb_per_token": bytes_per_token / 1024,
    }

# Generate comprehensive table
print(f"{'Model':<22} {'Seq Len':>10} {'KV Cache':>10} {'Weights':>10} {'Total':>10} {'KV/Weight':>10}")
print("=" * 75)
for name, config in MODEL_CONFIGS.items():
    for seq_len in [4_096, 32_768, 128_000, 1_000_000]:
        mem = kv_cache_memory(config, seq_len)
        total = mem["total_gb"] + config["weight_memory_fp16_gb"]
        ratio = mem["total_gb"] / config["weight_memory_fp16_gb"] * 100
        print(f"{name:<22} {seq_len:>10,} {mem['total_gb']:>9.1f}G "
              f"{config['weight_memory_fp16_gb']:>9}G {total:>9.1f}G {ratio:>9.1f}%")
    print("-" * 75)
```

Standard Multi-Head Attention (MHA): Every query head has its own KV head.
Grouped-Query Attention (GQA): Multiple query heads share one KV head.
Savings factor: `H_q / H_kv`, the ratio of query heads to KV heads. For Llama 3.1 70B (`H_q = 64`, `H_kv = 8`), that is an 8× reduction in KV cache.
```python
def gqa_savings_analysis(config: dict, seq_length: int = 128_000):
    """Compare MHA vs GQA vs MQA KV cache memory."""
    # MHA: every query head gets its own KV head
    mha_mem = kv_cache_memory({**config, "kv_heads": config["q_heads"]}, seq_length)
    # GQA: the model's actual KV head count
    gqa_mem = kv_cache_memory(config, seq_length)
    # MQA: a single KV head shared by all query heads
    mqa_mem = kv_cache_memory({**config, "kv_heads": 1}, seq_length)
    reduction_gqa = mha_mem["total_gb"] / gqa_mem["total_gb"]
    reduction_mqa = mha_mem["total_gb"] / mqa_mem["total_gb"]
    print(f"KV Cache at {seq_length:,} tokens:")
    print(f"  MHA ({config['q_heads']} KV heads): {mha_mem['total_gb']:.1f} GB")
    print(f"  GQA ({config['kv_heads']} KV heads): {gqa_mem['total_gb']:.1f} GB ({reduction_gqa:.0f}× reduction)")
    print(f"  MQA (1 KV head): {mqa_mem['total_gb']:.1f} GB ({reduction_mqa:.0f}× reduction)")

gqa_savings_analysis(MODEL_CONFIGS["Llama-3.1-70B"])
```

Output:
```
KV Cache at 128,000 tokens:
  MHA (64 KV heads): 312.5 GB
  GQA (8 KV heads): 39.1 GB (8× reduction)
  MQA (1 KV head): 4.9 GB (64× reduction)
```

Without GQA, running Llama 70B at 128K context would need 312 GB just for KV cache — nearly four H100s' worth of memory for the cache alone.
```python
def quantization_comparison(config: dict, seq_length: int):
    """Compare KV cache at different quantization levels."""
    precisions = [
        ("FP32", "fp32"),
        ("FP16/BF16", "fp16"),
        ("INT8", "int8"),
        ("INT4", "int4"),
    ]
    print(f"\n{config.get('name', 'Model')} at {seq_length:,} tokens:")
    print(f"{'Precision':<12} {'KV Cache':>10} {'KB/Token':>10} {'Quality':>15}")
    print("-" * 50)
    for name, prec in precisions:
        mem = kv_cache_memory(config, seq_length, precision=prec)
        quality = {
            "fp32": "Baseline",
            "fp16": "~Same",
            "int8": "~0.5% drop",
            "int4": "~2-5% drop",
        }[prec]
        print(f"{name:<12} {mem['total_gb']:>9.1f}G {mem['kb_per_token']:>9.1f} {quality:>15}")

config = MODEL_CONFIGS["Llama-3.1-70B"]
config["name"] = "Llama-3.1-70B"
quantization_comparison(config, 128_000)
quantization_comparison(config, 1_000_000)
```

Given a GPU with `M_gpu` GB of memory, the maximum context length is:

```
S_max = (M_gpu - M_weights - M_overhead) / bytes_per_token
```
where `M_overhead` covers activations, workspace, and OS overhead (typically 2-4 GB).
```python
def max_context_length(
    config: dict,
    gpu_memory_gb: float,
    overhead_gb: float = 3.0,
    batch_size: int = 1,
    precision: str = "fp16",
) -> int:
    """Calculate maximum context length for a given GPU."""
    p_bytes = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1, "int4": 0.5}[precision]
    available = (gpu_memory_gb - config["weight_memory_fp16_gb"] - overhead_gb) * (1024**3)
    if available <= 0:
        return 0  # Model doesn't fit
    bytes_per_token = 2 * config["layers"] * config["kv_heads"] * config["head_dim"] * p_bytes * batch_size
    return int(available / bytes_per_token)

# GPU configurations
gpus = [
    ("A100 40GB", 40),
    ("A100 80GB", 80),
    ("H100 80GB", 80),
    ("8× H100", 640),
    ("8× H200", 1120),
]

print(f"{'Model':<20} {'GPU':>12} {'Max Context':>12} {'Pages':>8}")
print("=" * 55)
for model_name, config in [("Llama-3.1-8B", MODEL_CONFIGS["Llama-3.1-8B"]),
                           ("Llama-3.1-70B", MODEL_CONFIGS["Llama-3.1-70B"])]:
    for gpu_name, gpu_mem in gpus:
        max_ctx = max_context_length(config, gpu_mem)
        # Rough conversion: ~0.75 words per token, ~250 words per page
        pages = max_ctx * 0.75 / 250 if max_ctx > 0 else 0
        ctx_str = f"{max_ctx:,}" if max_ctx > 0 else "N/A"
        print(f"{model_name:<20} {gpu_name:>12} {ctx_str:>12} {pages:>7,.0f}")
    print("-" * 55)
```

A handy reference for quick mental math — how many bytes each token costs in the KV cache:
| Model | FP16 (KB/token) | INT8 (KB/token) |
|---|---|---|
| Llama 3.1 8B | 131 KB | 65 KB |
| Llama 3.1 70B | 328 KB | 164 KB |
| Llama 3.1 405B | 516 KB | 258 KB |
Quick calculation: For Llama 70B at FP16, every 3 tokens costs roughly 1 MB of GPU memory in KV cache.
At 1M tokens: 1,000,000 / 3 ≈ 333,000 MB ≈ 325 GB.
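The cheat-sheet values fall straight out of the bytes-per-token formula; a minimal sketch (KB here is decimal, matching the table):

```python
def kb_per_token(layers: int, kv_heads: int, head_dim: int, p_bytes: int) -> float:
    """Decimal KB of KV cache per token: 2 (K+V) × layers × KV heads × head dim × bytes/element."""
    return 2 * layers * kv_heads * head_dim * p_bytes / 1000

print(kb_per_token(32, 8, 128, 2))   # Llama 3.1 8B,   FP16 -> 131.072
print(kb_per_token(80, 8, 128, 2))   # Llama 3.1 70B,  FP16 -> 327.68
print(kb_per_token(126, 8, 128, 1))  # Llama 3.1 405B, INT8 -> 258.048
```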
KV cache memory is the most predictable and most impactful cost in LLM deployment. With the formulas above, you can calculate the exact memory requirements for any model, any context length, any precision, and any GPU configuration — no guesswork needed.
The key insight: every token you add to the context costs real GPU memory. At scale, this directly translates to hardware costs and latency. Smart context management — including only what’s needed — saves both.
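To put a number on that, using the 328 KB/token figure for Llama 70B at FP16 from the cheat sheet, trimming context translates directly into freed KV cache memory (a rough sketch):

```python
# Memory freed by trimming context, at ~328 KB of KV cache per token (Llama 70B, FP16)
KB_PER_TOKEN = 328
for trimmed in (10_000, 50_000, 100_000):
    freed_gb = trimmed * KB_PER_TOKEN / 1e6
    print(f"Trim {trimmed:>7,} tokens -> frees ~{freed_gb:.1f} GB of KV cache")
```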
ByteBell helps engineering teams solve exactly this problem. Instead of stuffing everything into the context window, ByteBell’s Smart Context Refresh retrieves only what matters — keeping your AI sharp, fast, and accurate. Learn more at bytebell.ai