The KV Cache: Why Your AI Needs So Much GPU Memory
The Simple Version: Why Store Keys and Values?
Imagine you’re writing an essay, and every time you write a new word, you need to re-read the entire essay from the beginning to decide what word comes next. That would be incredibly slow, right?
That’s exactly what a transformer would do without the KV cache. For every new token it generates, it would recompute the Key and Value vectors for every previous token — even though those vectors never change.
The KV cache is the solution: store the Key and Value vectors for all previous tokens so you only compute them once. It’s like keeping sticky notes on every paragraph you’ve already written so you don’t have to re-read them.
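The saving can be sketched with a toy, single-layer decoder in pure Python. Everything here is hypothetical and for illustration only: random weights, a hand-rolled `project` matrix multiply, and a `decode` helper that processes one token per step. The sketch relies on the property described above: in a causal decoder, a token's K and V never change once computed. So the cached and uncached paths produce identical keys and values, but the uncached path performs $n(n+1)$ projections for $n$ tokens instead of $2n$.

```python
import random


def project(x, w):
    """Toy linear projection: row vector x times matrix w."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def decode(inputs, w_k, w_v, use_cache):
    """Process inputs one token at a time; return final keys, values, projection count."""
    cache_k, cache_v = [], []
    keys, values = [], []
    projections = 0
    for t in range(len(inputs)):
        if use_cache:
            # Only the newest token gets new K and V; earlier entries are reused as-is.
            cache_k.append(project(inputs[t], w_k))
            cache_v.append(project(inputs[t], w_v))
            projections += 2
            keys, values = cache_k, cache_v
        else:
            # No cache: recompute K and V for every token seen so far.
            keys = [project(x, w_k) for x in inputs[: t + 1]]
            values = [project(x, w_v) for x in inputs[: t + 1]]
            projections += 2 * (t + 1)
        # A real model would now attend from token t's query over `keys` and `values`.
    return keys, values, projections


random.seed(0)
d, n = 4, 8  # toy hidden size and number of tokens
inputs = [[random.random() for _ in range(d)] for _ in range(n)]
w_k = [[random.random() for _ in range(d)] for _ in range(d)]
w_v = [[random.random() for _ in range(d)] for _ in range(d)]

k1, v1, cached = decode(inputs, w_k, w_v, use_cache=True)
k2, v2, uncached = decode(inputs, w_k, w_v, use_cache=False)
print(k1 == k2 and v1 == v2, cached, uncached)  # True 16 72
```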
But here’s the catch: those sticky notes take up space. A lot of space.
What’s Actually Stored
For each token in the context, the model stores two vectors per layer per key-value (KV) head:
- Key vector ($K$): used to determine relevance
- Value vector ($V$): used to contribute information
The sizes involved:
- Each vector has $d_{\text{head}}$ dimensions (the head dimension)
- There are $n_{\text{kv}}$ key-value heads per layer
- There are $L$ layers in the model
- There are $s$ tokens in the sequence
The KV Cache Memory Formula
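$$\text{KV cache (bytes)} = 2 \times L \times n_{\text{kv}} \times d_{\text{head}} \times s \times b \times p$$
Where:
- $2$ = Keys and Values (two tensors)
- $L$ = number of transformer layers
- $n_{\text{kv}}$ = number of KV attention heads
- $d_{\text{head}}$ = dimension per head
- $s$ = sequence length (number of tokens)
- $b$ = batch size (number of concurrent requests)
- $p$ = bytes per element (2 for FP16, 1 for INT8)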
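Each factor in the formula corresponds to a tensor dimension. The sketch below assumes one common layout: a (K, V) pair per layer, with each tensor shaped `[batch, kv_heads, seq_len, head_dim]`. Real frameworks may order the axes differently, which does not change the byte count. `kv_cache_shapes` and `cache_bytes` are hypothetical helpers.

```python
import math


def kv_cache_shapes(num_layers, num_kv_heads, head_dim, seq_len, batch_size):
    """One (K, V) pair per layer, each shaped [batch, kv_heads, seq_len, head_dim]."""
    shape = (batch_size, num_kv_heads, seq_len, head_dim)
    return [(shape, shape) for _ in range(num_layers)]


def cache_bytes(shapes, bytes_per_element):
    """Sum the sizes of every K and V tensor in the cache."""
    total_elements = sum(math.prod(k) + math.prod(v) for k, v in shapes)
    return total_elements * bytes_per_element


# Llama 3.1 8B-style config: 32 layers, 8 KV heads, head_dim 128, FP16 (2 bytes)
shapes = kv_cache_shapes(num_layers=32, num_kv_heads=8, head_dim=128,
                         seq_len=4096, batch_size=1)
total = cache_bytes(shapes, bytes_per_element=2)
formula = 2 * 32 * 8 * 128 * 4096 * 1 * 2  # 2 x L x n_kv x d_head x s x b x p
print(total, total == formula)  # 536870912 True
```

Summing tensor sizes layer by layer gives exactly the closed-form product, which is why the formula can be used for capacity planning without building any tensors.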
Let’s work through a real example.
Example: Llama 3.1 70B
The Llama 3.1 70B architecture:
- $L = 80$ layers
- $n_{\text{kv}} = 8$ KV heads (using Grouped-Query Attention)
- $d_{\text{head}} = 128$ dimensions per head
- $p = 2$ bytes (FP16)
- $b = 1$ (single request)
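Each token therefore costs $2 \times 80 \times 8 \times 128 \times 2 = 327{,}680$ bytes (320 KiB) of cache.
At 128K context ($s = 128{,}000$):
$$327{,}680 \times 128{,}000 = 41{,}943{,}040{,}000 \text{ bytes} \approx 41.9 \text{ GB} \; (39.1 \text{ GiB})$$
At 1M context ($s = 1{,}000{,}000$):
$$327{,}680 \times 1{,}000{,}000 = 327{,}680{,}000{,}000 \text{ bytes} \approx 327.7 \text{ GB} \; (305.2 \text{ GiB})$$
That's about 328 GB just for the KV cache, not counting the model weights (~140 GB in FP16) or activations!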
KV Cache Calculator
```python
def kv_cache_memory(
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    seq_length: int,
    batch_size: int = 1,
    precision_bytes: float = 2,  # 2 for FP16/BF16, 1 for INT8, 0.5 for INT4
) -> dict:
    """
    Calculate KV cache memory requirements.
    Returns memory in bytes, MB, and GB (binary units: 1 GB here = 1024**3 bytes).
    """
    # 2 for K and V
    memory_bytes = (
        2 * num_layers * num_kv_heads * head_dim
        * seq_length * batch_size * precision_bytes
    )
    return {
        "bytes": memory_bytes,
        "MB": memory_bytes / (1024 ** 2),
        "GB": memory_bytes / (1024 ** 3),
    }


# Model configurations
models = {
    "Llama 3.1 8B": {
        "num_layers": 32, "num_kv_heads": 8,
        "head_dim": 128, "model_params": "8B"
    },
    "Llama 3.1 70B": {
        "num_layers": 80, "num_kv_heads": 8,
        "head_dim": 128, "model_params": "70B"
    },
    "Llama 3.1 405B": {
        "num_layers": 126, "num_kv_heads": 8,
        "head_dim": 128, "model_params": "405B"
    },
    "Mistral 7B": {
        "num_layers": 32, "num_kv_heads": 8,
        "head_dim": 128, "model_params": "7B"
    },
}

context_lengths = [4_096, 32_768, 128_000, 1_000_000]

print(f"{'Model':<20} {'Context':>10} {'KV Cache (GB)':>15} {'Precision':>10}")
print("=" * 60)
for model_name, config in models.items():
    for ctx_len in context_lengths:
        mem = kv_cache_memory(
            num_layers=config["num_layers"],
            num_kv_heads=config["num_kv_heads"],
            head_dim=config["head_dim"],
            seq_length=ctx_len,
        )
        print(f"{model_name:<20} {ctx_len:>10,} {mem['GB']:>15.2f} {'FP16':>10}")
    print("-" * 60)
```

Output:
```
Model                   Context   KV Cache (GB)  Precision
============================================================
Llama 3.1 8B              4,096            0.50       FP16
Llama 3.1 8B             32,768            4.00       FP16
Llama 3.1 8B            128,000           15.62       FP16
Llama 3.1 8B          1,000,000          122.07       FP16
------------------------------------------------------------
Llama 3.1 70B             4,096            1.25       FP16
Llama 3.1 70B            32,768           10.00       FP16
Llama 3.1 70B           128,000           39.06       FP16
Llama 3.1 70B         1,000,000          305.18       FP16
------------------------------------------------------------
Llama 3.1 405B            4,096            1.97       FP16
Llama 3.1 405B           32,768           15.75       FP16
Llama 3.1 405B          128,000           61.52       FP16
Llama 3.1 405B        1,000,000          480.65       FP16
------------------------------------------------------------
Mistral 7B                4,096            0.50       FP16
Mistral 7B               32,768            4.00       FP16
Mistral 7B              128,000           15.62       FP16
Mistral 7B            1,000,000          122.07       FP16
------------------------------------------------------------
```

Why KV Cache Dominates Memory
Let’s compare the KV cache to model weights for Llama 3.1 70B (all figures in GiB, matching the calculator output above):
| Component | Memory (FP16, GiB) |
|---|---|
| Model weights | ~130 (≈140 GB) |
| KV cache at 4K | 1.25 |
| KV cache at 128K | 39.06 |
| KV cache at 1M | 305.18 |
At short contexts, the KV cache is negligible. But at 1M tokens, the KV cache is about 2.3× the size of the model weights.
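The crossover point, where KV cache memory equals model weight memory, is the weight size divided by the per-token cache cost:
$$s_{\text{cross}} = \frac{\text{weight bytes}}{2 \times L \times n_{\text{kv}} \times d_{\text{head}} \times b \times p}$$
For Llama 3.1 70B:
$$s_{\text{cross}} = \frac{140 \times 10^{9}}{2 \times 80 \times 8 \times 128 \times 1 \times 2} = \frac{140 \times 10^{9}}{327{,}680} \approx 427{,}000 \text{ tokens}$$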
Above ~427K tokens, the KV cache uses more memory than the model weights.
Grouped-Query Attention: The 8× Trick
Standard multi-head attention (MHA) uses the same number of KV heads as query heads. If the model has 64 query heads, it also has 64 KV heads.
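Grouped-Query Attention (GQA) shares KV heads across groups of query heads, so the cache shrinks by the ratio of query heads ($n_{\text{q}}$) to KV heads ($n_{\text{kv}}$):
$$\text{reduction factor} = \frac{n_{\text{q}}}{n_{\text{kv}}}$$
For Llama 3.1 70B, with 64 query heads and 8 KV heads:
$$\frac{64}{8} = 8\times \text{ smaller}$$
Without GQA, the KV cache at 128K would be:
$$39.06 \text{ GiB} \times 8 = 312.5 \text{ GiB}$$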
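GQA is a major reason long-context inference is practical on current hardware: with 8 KV heads instead of 64, a single 128K-token request to Llama 3.1 70B needs about 39 GiB of cache instead of about 312 GiB.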
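How each query head finds its shared KV head can be sketched as follows, assuming contiguous grouping (query heads 0 to 7 read KV head 0, heads 8 to 15 read KV head 1, and so on). The helper names are hypothetical. Because only KV heads are cached, the cache size relative to MHA is simply the number of KV heads divided by the number of query heads.

```python
def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Map a query head to the KV head it reads from (contiguous groups assumed)."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads  # query heads per KV head
    return q_head // group_size


def kv_cache_ratio(num_q_heads: int, num_kv_heads: int) -> float:
    """KV cache size relative to MHA, which caches one KV head per query head."""
    return num_kv_heads / num_q_heads


# Llama 3.1 70B-style head counts: 64 query heads sharing 8 KV heads
groups = [kv_head_for_query_head(q, 64, 8) for q in range(64)]
print(groups[:9])              # [0, 0, 0, 0, 0, 0, 0, 0, 1]
print(kv_cache_ratio(64, 8))   # 0.125
print(kv_cache_ratio(64, 64))  # 1.0 (standard MHA)
```

At attention time, the 8 query heads in a group all read the same cached K and V, so the per-head compute is unchanged; only the stored state shrinks.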
KV Cache Quantization
Another way to reduce KV cache size is to store the keys and values at lower precision:
| Precision | Bytes per Element | KV Cache at 128K tokens, Llama 3.1 70B (GiB) |
|---|---|---|
| FP32 | 4 | 78.1 |
| FP16 / BF16 | 2 | 39.1 |
| INT8 | 1 | 19.5 |
| INT4 | 0.5 | 9.8 |
```python
def compare_quantization(model_config, seq_length):
    """Compare KV cache size at different quantization levels.

    Uses kv_cache_memory() from the calculator above; savings are relative to FP32.
    """
    precisions = {
        "FP32": 4,
        "FP16": 2,
        "INT8": 1,
        "INT4": 0.5,  # ignores the small overhead of storing quantization scales
    }
    print(f"\nKV Cache at {seq_length:,} tokens:")
    print(f"{'Precision':<10} {'Memory (GB)':>12} {'Savings':>10}")
    print("-" * 35)
    base_mem = None
    for name, p_bytes in precisions.items():
        mem = kv_cache_memory(
            **model_config,
            seq_length=seq_length,
            precision_bytes=p_bytes,
        )
        if base_mem is None:
            base_mem = mem["GB"]  # the first entry (FP32) is the baseline
        savings = (1 - mem["GB"] / base_mem) * 100
        print(f"{name:<10} {mem['GB']:>12.2f} {savings:>9.1f}%")


compare_quantization(
    {"num_layers": 80, "num_kv_heads": 8, "head_dim": 128},
    seq_length=128_000,
)
```

INT8 quantization halves the KV cache relative to FP16, usually with minimal quality loss. INT4 is more aggressive and is more likely to degrade model quality.
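What INT8 storage involves can be sketched with one simple scheme: symmetric absmax quantization with a single float scale per vector. This is an illustrative assumption, not the scheme of any particular serving engine; real systems may quantize per token, per channel, or in small groups, and those choices affect quality. The helper names are hypothetical.

```python
def quantize_int8(vec: list[float]) -> tuple[list[int], float]:
    """Map floats to integers in [-127, 127] using one scale per vector."""
    max_abs = max(abs(x) for x in vec)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    codes = [round(x / scale) for x in vec]
    return codes, scale


def dequantize_int8(codes: list[int], scale: float) -> list[float]:
    """Approximately reconstruct the original floats."""
    return [c * scale for c in codes]


key = [0.12, -1.5, 0.73, 0.004, -0.9, 1.1, 0.0, -0.33]  # made-up key vector
codes, scale = quantize_int8(key)
restored = dequantize_int8(codes, scale)
max_err = max(abs(a - b) for a, b in zip(key, restored))
print(codes)
print(f"scale={scale:.5f}, max abs error={max_err:.5f}")
```

Rounding bounds the per-element error by half the scale. The scale itself is extra storage: with $d_{\text{head}} = 128$, a 4-byte FP32 scale per vector adds about 3% on top of the 128 bytes of INT8 codes.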
Paged Attention: vLLM’s Innovation
Traditional KV caches pre-allocate contiguous memory for the maximum sequence length. If you allocate for 128K tokens but only use 10K, 92% of memory is wasted.
Paged Attention (from the vLLM project) borrows ideas from operating system virtual memory: allocate memory in small pages (blocks) and map them via a page table.
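```python
# Conceptual paged attention: a block allocator only, no attention kernel
class PagedKVCache:
    def __init__(self, block_size=16, max_blocks=1000):
        self.block_size = block_size                # tokens per block
        self.free_blocks = list(range(max_blocks))  # pool of physical block ids
        self.blocks = {}                            # block_id -> K, V tensors
        self.page_table = {}                        # sequence_id -> [block_ids]

    def allocate(self, seq_id, num_tokens):
        """Allocate only as many blocks as needed (ceiling division)."""
        num_blocks = (num_tokens + self.block_size - 1) // self.block_size
        if num_blocks > len(self.free_blocks):
            raise MemoryError("not enough free KV cache blocks")
        block_ids = [self._get_free_block() for _ in range(num_blocks)]
        self.page_table[seq_id] = block_ids
        return block_ids

    def free(self, seq_id):
        """Return a finished sequence's blocks to the pool."""
        for block_id in self.page_table.pop(seq_id, []):
            del self.blocks[block_id]
            self.free_blocks.append(block_id)

    def _get_free_block(self):
        block_id = self.free_blocks.pop()
        self.blocks[block_id] = None  # Will hold K, V tensors
        return block_id
```

Benefits: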
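- Little wasted memory: allocation happens block by block, so at most the last, partially filled block of each sequence is unused
- Easy sharing: multiple sequences can share blocks (e.g., the same system prompt)
- Dynamic growth: a sequence can grow one block at a time without pre-allocation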
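The page-table lookup and the 92% waste figure above can be sketched as follows. The helpers are hypothetical (they are not vLLM's API), the block size of 16 tokens matches the conceptual class above, and the page table is hand-written for illustration.

```python
def physical_location(page_table: dict, seq_id: str, token_pos: int, block_size: int = 16):
    """Translate a logical token position into (physical block id, offset within block)."""
    block_ids = page_table[seq_id]
    return block_ids[token_pos // block_size], token_pos % block_size


def wasted_fraction_contiguous(max_len: int, used: int) -> float:
    """Unused share of a contiguous reservation sized for max_len tokens."""
    return 1 - used / max_len


def wasted_fraction_paged(used: int, block_size: int = 16) -> float:
    """Unused share with paging: only the last block can be partially empty."""
    num_blocks = -(-used // block_size)  # ceiling division
    return 1 - used / (num_blocks * block_size)


# Hypothetical page table: sequence "a" occupies non-contiguous blocks 7, 2, 9.
page_table = {"a": [7, 2, 9]}
print(physical_location(page_table, "a", 20))                 # (2, 4)
print(f"{wasted_fraction_contiguous(128_000, 10_000):.1%}")  # 92.2%
print(f"{wasted_fraction_paged(10_001):.2%}")                # 0.15%
```

Token 20 lands in the second logical block, which the page table maps to physical block 2, so the sequence's blocks never need to be contiguous in GPU memory.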
The Takeaway
The KV cache is the hidden cost of long context windows. It scales linearly with both sequence length and batch size, and at long contexts it can dwarf the model weights themselves. Every token you add to the context costs real GPU memory, and at $2–3 per GPU-hour for H100s, that translates directly to dollars.
Understanding KV cache math lets you:
- Right-size your GPU fleet for your context requirements
- Choose the right optimization (GQA, quantization, paging)
- Make informed decisions about context length vs. cost
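For right-sizing, the formula can be inverted to ask how many cached tokens fit in a memory budget. The sketch below treats all GPU memory as one pool and uses a hypothetical 10% reserve for activations and fragmentation. The 8 × 80 GB budget and ~140 GB of FP16 weights are assumptions for illustration; real tensor-parallel deployments split the cache across GPUs and have their own overheads.

```python
def max_tokens_for_budget(gpu_mem_bytes, weight_bytes, num_layers, num_kv_heads,
                          head_dim, precision_bytes=2, reserve_fraction=0.1):
    """How many cached tokens (summed over all concurrent requests) fit after the weights.

    reserve_fraction is a hypothetical headroom for activations and fragmentation.
    """
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * precision_bytes
    usable = gpu_mem_bytes * (1 - reserve_fraction) - weight_bytes
    return max(0, int(usable // bytes_per_token))


# Assumed setup: Llama 3.1 70B in FP16 (~140 GB of weights) on 8 GPUs with 80 GB each.
tokens = max_tokens_for_budget(
    gpu_mem_bytes=8 * 80e9, weight_bytes=140e9,
    num_layers=80, num_kv_heads=8, head_dim=128,
)
print(f"{tokens:,} tokens")
```

Because batch size and sequence length enter the formula as a product, the result is a shared budget: it could serve roughly one 1M-token request or about ten 128K-token requests.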
ByteBell helps engineering teams solve exactly this problem. Instead of stuffing everything into the context window, ByteBell’s Private Code Context retrieves only what matters — keeping your AI sharp, fast, and accurate. Learn more at bytebell.ai