During inference, the model stores Key and Value vectors for every token. This KV cache is often the biggest memory consumer. Here's the math behind it.
Imagine you’re writing an essay, and every time you write a new word, you need to re-read the entire essay from the beginning to decide what word comes next. That would be incredibly slow, right?
That’s exactly what a transformer would do without the KV cache. For every new token it generates, it would recompute the Key and Value vectors for every previous token — even though those vectors never change.
The KV cache is the solution: store the Key and Value vectors for all previous tokens so you only compute them once. It’s like keeping sticky notes on every paragraph you’ve already written so you don’t have to re-read them.
But here’s the catch: those sticky notes take up space. A lot of space.
For each token in the context, the model stores two vectors (one Key, one Value) per layer, per KV attention head. The total cache size is:

KV cache bytes = 2 × num_layers × num_kv_heads × head_dim × seq_length × batch_size × precision_bytes

Where:

- `2` accounts for storing both K and V
- `num_layers` is the number of transformer layers
- `num_kv_heads` is the number of KV attention heads (fewer than query heads under GQA)
- `head_dim` is the dimension of each attention head
- `seq_length` is the number of tokens in the context
- `batch_size` is the number of sequences served concurrently
- `precision_bytes` is bytes per element (2 for FP16/BF16, 1 for INT8)
Let’s work through a real example.
The Llama 3.1 70B architecture:

- `num_layers` = 80
- `num_kv_heads` = 8 (GQA: 64 query heads share 8 KV heads)
- `head_dim` = 128
- FP16 precision (2 bytes per element)

Per token, that's 2 × 80 × 8 × 128 × 2 = 327,680 bytes, or about 0.31 MB of cache for every token in the context.

At 128K context:

327,680 bytes/token × 128,000 tokens ≈ 39 GB

At 1M context:

327,680 bytes/token × 1,000,000 tokens ≈ 305 GB
That’s roughly 305 GB just for the KV cache — not counting the model weights (140 GB in FP16) or activations!
```python
def kv_cache_memory(
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    seq_length: int,
    batch_size: int = 1,
    precision_bytes: float = 2,  # 2 for FP16, 1 for INT8, 0.5 for INT4
) -> dict:
    """
    Calculate KV cache memory requirements.
    Returns memory in bytes, MB, and GB.
    """
    # 2 for K and V
    memory_bytes = (
        2 * num_layers * num_kv_heads * head_dim
        * seq_length * batch_size * precision_bytes
    )
    return {
        "bytes": memory_bytes,
        "MB": memory_bytes / (1024 ** 2),
        "GB": memory_bytes / (1024 ** 3),
    }
```
```python
# Model configurations
models = {
    "Llama 3.1 8B": {
        "num_layers": 32, "num_kv_heads": 8,
        "head_dim": 128, "model_params": "8B"
    },
    "Llama 3.1 70B": {
        "num_layers": 80, "num_kv_heads": 8,
        "head_dim": 128, "model_params": "70B"
    },
    "Llama 3.1 405B": {
        "num_layers": 126, "num_kv_heads": 8,
        "head_dim": 128, "model_params": "405B"
    },
    "Mistral 7B": {
        "num_layers": 32, "num_kv_heads": 8,
        "head_dim": 128, "model_params": "7B"
    },
}

context_lengths = [4_096, 32_768, 128_000, 1_000_000]

print(f"{'Model':<20} {'Context':>10} {'KV Cache (GB)':>15} {'Precision':>10}")
print("=" * 60)
for model_name, config in models.items():
    for ctx_len in context_lengths:
        mem = kv_cache_memory(
            num_layers=config["num_layers"],
            num_kv_heads=config["num_kv_heads"],
            head_dim=config["head_dim"],
            seq_length=ctx_len,
        )
        print(f"{model_name:<20} {ctx_len:>10,} {mem['GB']:>15.2f} {'FP16':>10}")
    print("-" * 60)
```

Output:
```
Model                   Context   KV Cache (GB)  Precision
============================================================
Llama 3.1 8B              4,096            0.50       FP16
Llama 3.1 8B             32,768            4.00       FP16
Llama 3.1 8B            128,000           15.63       FP16
Llama 3.1 8B          1,000,000          122.07       FP16
------------------------------------------------------------
Llama 3.1 70B             4,096            1.25       FP16
Llama 3.1 70B            32,768           10.00       FP16
Llama 3.1 70B           128,000           39.06       FP16
Llama 3.1 70B         1,000,000          305.18       FP16
------------------------------------------------------------
Llama 3.1 405B            4,096            1.97       FP16
Llama 3.1 405B           32,768           15.75       FP16
Llama 3.1 405B          128,000           61.52       FP16
Llama 3.1 405B        1,000,000          480.94       FP16
------------------------------------------------------------
Mistral 7B                4,096            0.50       FP16
Mistral 7B               32,768            4.00       FP16
Mistral 7B              128,000           15.63       FP16
Mistral 7B            1,000,000          122.07       FP16
------------------------------------------------------------
```

Let’s compare the KV cache to model weights for Llama 3.1 70B:
| Component | Memory (FP16) |
|---|---|
| Model weights | ~140 GB |
| KV cache at 4K | 1.25 GB |
| KV cache at 128K | 39 GB |
| KV cache at 1M | 305 GB |
At short contexts, the KV cache is negligible. But at 1M tokens, the KV cache is 2.2× larger than the model itself.
The crossover point — where KV cache equals model weight memory — can be calculated:

seq_crossover = weight_bytes / (2 × num_layers × num_kv_heads × head_dim × precision_bytes)

For Llama 3.1 70B:

seq_crossover = 140 × 10⁹ / (2 × 80 × 8 × 128 × 2) = 140 × 10⁹ / 327,680 ≈ 427,000 tokens
Above ~427K tokens, the KV cache uses more memory than the model weights.
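As a sanity check, the crossover falls straight out of the per-token figure. A minimal sketch (`weight_bytes` assumes the 70B model's FP16 footprint of 140 GB):

```python
# Crossover: context length at which KV cache memory equals weight memory
num_layers, num_kv_heads, head_dim, precision_bytes = 80, 8, 128, 2

bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * precision_bytes
weight_bytes = 140e9  # ~70B parameters at 2 bytes each (FP16)

crossover_tokens = weight_bytes / bytes_per_token
print(f"{bytes_per_token:,} bytes/token -> crossover at ~{crossover_tokens:,.0f} tokens")
```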
Standard multi-head attention (MHA) uses the same number of KV heads as query heads. If the model has 64 query heads, it also has 64 KV heads.
Grouped-Query Attention (GQA) shares each KV head across a group of query heads, so the cache stores `num_kv_heads` sets of K/V per token rather than one set per query head.

For Llama 3.1 70B: 64 query heads, 8 KV heads. Each KV head serves a group of 8 query heads, shrinking the cache 8×.

Without GQA, the KV cache at 128K would be 39 GB × 8 = 312.5 GB, nearly four 80 GB H100s' worth of memory for the cache alone.
GQA makes long-context inference possible on current hardware.
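A minimal sketch of that 8× difference, reusing the same cache formula (`kv_bytes` is a local helper, redefined inline so the snippet stands alone):

```python
def kv_bytes(num_layers, num_kv_heads, head_dim, seq_len, precision_bytes=2):
    # 2 accounts for storing both K and V per token
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * precision_bytes

ctx = 128_000
mha = kv_bytes(80, 64, 128, ctx)  # full MHA: one KV head per query head
gqa = kv_bytes(80, 8, 128, ctx)   # GQA: 8 KV heads shared by 64 query heads

print(f"MHA: {mha / 1024**3:.1f} GB, GQA: {gqa / 1024**3:.1f} GB ({mha // gqa}x smaller)")
```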
Another way to reduce KV cache size: use lower precision:
| Precision | Bytes per Element | KV Cache at 128K (70B) |
|---|---|---|
| FP32 | 4 | 78 GB |
| FP16 / BF16 | 2 | 39 GB |
| INT8 | 1 | 19.5 GB |
| INT4 | 0.5 | 9.75 GB |
```python
def compare_quantization(model_config, seq_length):
    """Compare KV cache size at different quantization levels."""
    precisions = {
        "FP32": 4,
        "FP16": 2,
        "INT8": 1,
        "INT4": 0.5,
    }
    print(f"\nKV Cache at {seq_length:,} tokens:")
    print(f"{'Precision':<10} {'Memory (GB)':>12} {'Savings':>10}")
    print("-" * 35)
    base_mem = None
    for name, p_bytes in precisions.items():
        mem = kv_cache_memory(
            **model_config,
            seq_length=seq_length,
            precision_bytes=p_bytes,
        )
        if base_mem is None:
            base_mem = mem["GB"]
        savings = (1 - mem["GB"] / base_mem) * 100
        print(f"{name:<10} {mem['GB']:>12.2f} {savings:>9.1f}%")
```
```python
compare_quantization(
    {"num_layers": 80, "num_kv_heads": 8, "head_dim": 128},
    seq_length=128_000,
)
```

INT8 quantization halves the KV cache with minimal quality loss. INT4 is more aggressive but can degrade model quality.
Traditional KV caches pre-allocate contiguous memory for the maximum sequence length. If you allocate for 128K tokens but only use 10K, 92% of memory is wasted.
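To put numbers on that waste for the 70B example (illustrative figures, assuming a 128K pre-allocation and 10K tokens actually used):

```python
bytes_per_token = 2 * 80 * 8 * 128 * 2  # Llama 3.1 70B at FP16
max_ctx, used = 128_000, 10_000

allocated_gb = max_ctx * bytes_per_token / 1024**3
used_gb = used * bytes_per_token / 1024**3
print(f"Allocated {allocated_gb:.1f} GB, used {used_gb:.1f} GB "
      f"({1 - used / max_ctx:.0%} wasted)")
```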
Paged Attention (from the vLLM project) borrows ideas from operating system virtual memory: allocate memory in small pages (blocks) and map them via a page table.
```python
# Conceptual paged attention
class PagedKVCache:
    def __init__(self, block_size=16, max_blocks=1000):
        self.block_size = block_size
        self.max_blocks = max_blocks
        self.blocks = {}      # block_id -> tensor
        self.page_table = {}  # sequence_id -> [block_ids]

    def allocate(self, seq_id, num_tokens):
        """Allocate only as many blocks as needed."""
        # Ceiling division: a partial block still occupies a full block
        num_blocks = (num_tokens + self.block_size - 1) // self.block_size
        block_ids = [self._get_free_block() for _ in range(num_blocks)]
        self.page_table[seq_id] = block_ids
        return block_ids

    def _get_free_block(self):
        if len(self.blocks) >= self.max_blocks:
            raise MemoryError("KV cache is full")
        block_id = len(self.blocks)
        self.blocks[block_id] = None  # Will hold K, V tensors
        return block_id
```

Benefits:

- Memory is allocated on demand, so waste per sequence is bounded by less than one block (under 16 tokens' worth) instead of the full pre-allocated context.
- Blocks can be shared between sequences (for example, a common prompt prefix) with copy-on-write semantics.
- Freed blocks are returned to a shared pool, so far more concurrent requests fit on one GPU.
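To see the allocation granularity concretely, here is a tiny standalone sketch of the same ceiling-division bookkeeping (token counts are illustrative):

```python
def blocks_needed(num_tokens, block_size=16):
    # Ceiling division, as in a paged allocator
    return (num_tokens + block_size - 1) // block_size

for tokens in (10, 100, 10_000):
    n = blocks_needed(tokens)
    waste = n * 16 - tokens  # worst case: less than one block unused
    print(f"{tokens:>6} tokens -> {n:>4} blocks ({waste} slots unused)")
```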
The KV cache is the hidden cost of long context windows. It scales linearly with sequence length and can quickly dwarf the model weights themselves. Every token you add to the context costs real GPU memory — and at $2–3 per GPU-hour for H100s, that translates directly to dollars.
Understanding KV cache math lets you:

- Estimate serving memory before you provision hardware
- Set realistic context limits for your workload
- Weigh quantization (INT8/INT4) against output quality
- Recognize when you need techniques like GQA or paged attention
ByteBell helps engineering teams solve exactly this problem. Instead of stuffing everything into the context window, ByteBell’s Smart Context Refresh retrieves only what matters — keeping your AI sharp, fast, and accurate. Learn more at bytebell.ai