The KV Cache: Why Your AI Needs So Much GPU Memory

During inference, the model stores Key and Value vectors for every token. This KV cache is often the biggest memory consumer. Here's the math behind it.

The Simple Version: Why Store Keys and Values?

Imagine you’re writing an essay, and every time you write a new word, you need to re-read the entire essay from the beginning to decide what word comes next. That would be incredibly slow, right?

That’s exactly what a transformer would do without the KV cache. For every new token it generates, it would recompute the Key and Value vectors for every previous token — even though those vectors never change.

The KV cache is the solution: store the Key and Value vectors for all previous tokens so you only compute them once. It’s like keeping sticky notes on every paragraph you’ve already written so you don’t have to re-read them.

But here’s the catch: those sticky notes take up space. A lot of space.
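To make the idea concrete, here is a minimal pure-Python sketch of a decode loop, with made-up dimensions and weights. The point is only that each token's Key and Value are computed once, appended to the cache, and never recomputed:

```python
# Toy decode loop: compute K and V for each NEW token once, cache them,
# and reuse the cache on every later step. Dimensions are illustrative.
d_head = 4
k_cache, v_cache = [], []          # the KV cache: one entry per token

def project(x, weight):
    """Tiny stand-in for a linear projection (x @ W)."""
    return [sum(xi * w for xi, w in zip(x, col)) for col in weight]

W_k = [[0.1] * d_head] * d_head    # illustrative weight matrices
W_v = [[0.2] * d_head] * d_head

for step in range(5):
    token = [float(step)] * d_head        # embedding of the new token
    k_cache.append(project(token, W_k))   # computed once, then cached
    v_cache.append(project(token, W_v))
    # attention for this step reads ALL of k_cache / v_cache

print(len(k_cache), len(v_cache))  # cache grows by one entry per token: 5 5
```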

What’s Actually Stored

For each token in the context, the model stores two vectors per layer, per KV head:

  - A Key vector, used to compute attention scores for every later token
  - A Value vector, mixed into the attention output according to those scores

Each vector has d_h elements (the head dimension), and each element takes p bytes at the chosen precision, e.g., 2 bytes for FP16.
The KV Cache Memory Formula

M_{KV} = 2 \times L \times h_{kv} \times d_h \times s \times b \times p

Where:

  - 2 accounts for storing both a Key and a Value vector
  - L: number of transformer layers
  - h_kv: number of KV heads
  - d_h: dimension of each head
  - s: sequence (context) length in tokens
  - b: batch size
  - p: bytes per element (2 for FP16/BF16, 1 for INT8)

Let’s work through a real example.

Example: Llama 3.1 70B

The Llama 3.1 70B architecture:

  - L = 80 layers
  - h_kv = 8 KV heads (grouped-query attention)
  - d_h = 128 head dimension
  - b = 1 (single sequence), p = 2 bytes (FP16)

At 128K context:

M_{KV} = 2 \times 80 \times 8 \times 128 \times 128{,}000 \times 1 \times 2

= 41{,}943{,}040{,}000 \text{ bytes}

\approx 42 \text{ GB}

At 1M context:

M_{KV} = 2 \times 80 \times 8 \times 128 \times 1{,}000{,}000 \times 1 \times 2

= 327{,}680{,}000{,}000 \text{ bytes}

\approx 328 \text{ GB}

That’s 328 GB just for the KV cache — not counting the model weights (140 GB in FP16) or activations!
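A quick way to sanity-check these figures is to compute the per-token cost once and then scale by context length. A small Python sketch, using the same Llama 3.1 70B parameters as above:

```python
# Llama 3.1 70B, FP16: KV cache bytes added per token of context
layers, kv_heads, head_dim, p_bytes = 80, 8, 128, 2
per_token = 2 * layers * kv_heads * head_dim * p_bytes  # 2 = one K + one V

print(f"{per_token:,} bytes/token")                   # 327,680 (~320 KiB)
print(f"{per_token * 128_000 / 1e9:.1f} GB at 128K")  # 41.9 GB at 128K
print(f"{per_token * 1_000_000 / 1e9:.1f} GB at 1M")  # 327.7 GB at 1M
```

Every token of context costs about 320 KiB of GPU memory for this model, regardless of how it was generated.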

KV Cache Calculator

def kv_cache_memory(
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    seq_length: int,
    batch_size: int = 1,
    precision_bytes: int = 2,  # 2 for FP16, 1 for INT8
) -> dict:
    """
    Calculate KV cache memory requirements.

    Returns memory in bytes, MB, and GB.
    """
    # 2 for K and V
    memory_bytes = (
        2 * num_layers * num_kv_heads * head_dim
        * seq_length * batch_size * precision_bytes
    )

    return {
        "bytes": memory_bytes,
        "MB": memory_bytes / (1024 ** 2),
        "GB": memory_bytes / (1024 ** 3),
    }


# Model configurations
models = {
    "Llama 3.1 8B": {
        "num_layers": 32, "num_kv_heads": 8,
        "head_dim": 128, "model_params": "8B"
    },
    "Llama 3.1 70B": {
        "num_layers": 80, "num_kv_heads": 8,
        "head_dim": 128, "model_params": "70B"
    },
    "Llama 3.1 405B": {
        "num_layers": 126, "num_kv_heads": 8,
        "head_dim": 128, "model_params": "405B"
    },
    "Mistral 7B": {
        "num_layers": 32, "num_kv_heads": 8,
        "head_dim": 128, "model_params": "7B"
    },
}

context_lengths = [4_096, 32_768, 128_000, 1_000_000]

print(f"{'Model':<20} {'Context':>10} {'KV Cache (GB)':>15} {'Precision':>10}")
print("=" * 60)

for model_name, config in models.items():
    for ctx_len in context_lengths:
        mem = kv_cache_memory(
            num_layers=config["num_layers"],
            num_kv_heads=config["num_kv_heads"],
            head_dim=config["head_dim"],
            seq_length=ctx_len,
        )
        print(f"{model_name:<20} {ctx_len:>10,} {mem['GB']:>15.2f} {'FP16':>10}")
    print("-" * 60)

Output:

Model                   Context   KV Cache (GB)  Precision
============================================================
Llama 3.1 8B             4,096            0.50       FP16
Llama 3.1 8B            32,768            4.00       FP16
Llama 3.1 8B           128,000           15.63       FP16
Llama 3.1 8B         1,000,000          122.07       FP16
------------------------------------------------------------
Llama 3.1 70B            4,096            1.25       FP16
Llama 3.1 70B           32,768           10.00       FP16
Llama 3.1 70B          128,000           39.06       FP16
Llama 3.1 70B        1,000,000          305.18       FP16
------------------------------------------------------------
Llama 3.1 405B           4,096            1.97       FP16
Llama 3.1 405B          32,768           15.75       FP16
Llama 3.1 405B         128,000           61.52       FP16
Llama 3.1 405B       1,000,000          480.65       FP16
------------------------------------------------------------
Mistral 7B               4,096            0.50       FP16
Mistral 7B              32,768            4.00       FP16
Mistral 7B             128,000           15.63       FP16
Mistral 7B           1,000,000          122.07       FP16
------------------------------------------------------------

Why KV Cache Dominates Memory

Let’s compare the KV cache to model weights for Llama 3.1 70B:

Component              Memory (FP16)
------------------------------------
Model weights                ~140 GB
KV cache at 4K               1.25 GB
KV cache at 128K               39 GB
KV cache at 1M                305 GB

At short contexts, the KV cache is negligible. But at 1M tokens, the KV cache is 2.2× larger than the model itself.

The crossover point — where KV cache equals model weight memory — can be calculated:

s_{\text{crossover}} = \frac{M_{\text{weights}}}{2 \times L \times h_{kv} \times d_h \times p}

For Llama 3.1 70B:

s_{\text{crossover}} = \frac{140 \times 10^9}{2 \times 80 \times 8 \times 128 \times 2} = \frac{140 \times 10^9}{327{,}680} \approx 427{,}000 \text{ tokens}

Above ~427K tokens, the KV cache uses more memory than the model weights.
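The same estimate in code (a sketch; 140 GB is the approximate FP16 weight figure quoted above):

```python
def crossover_tokens(weight_bytes, num_layers, num_kv_heads, head_dim,
                     precision_bytes=2):
    """Context length at which KV cache memory equals model weight memory."""
    per_token_bytes = 2 * num_layers * num_kv_heads * head_dim * precision_bytes
    return weight_bytes / per_token_bytes

# Llama 3.1 70B with ~140 GB of FP16 weights
s = crossover_tokens(140e9, num_layers=80, num_kv_heads=8, head_dim=128)
print(f"{s:,.0f} tokens")  # 427,246 tokens
```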

Grouped-Query Attention: The 8× Trick

Standard multi-head attention (MHA) uses the same number of KV heads as query heads. If the model has 64 query heads, it also has 64 KV heads.

Grouped-Query Attention (GQA) shares KV heads across groups of query heads:

\text{Memory reduction} = \frac{h_Q}{h_{KV}}

For Llama 3.1 70B: 64 query heads, 8 KV heads:

\text{Reduction factor} = \frac{64}{8} = 8\times

Without GQA, the KV cache at 128K would be:

39 \text{ GB} \times 8 = 312 \text{ GB}
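The difference is easy to check numerically. A sketch using the same formula as earlier, comparing 64 KV heads (MHA) against 8 (GQA) for the 70B model at 128K context, FP16:

```python
def kv_cache_gib(num_kv_heads, num_layers=80, head_dim=128,
                 seq_length=128_000, precision_bytes=2):
    """KV cache size in GiB for a Llama-3.1-70B-shaped model."""
    return (2 * num_layers * num_kv_heads * head_dim
            * seq_length * precision_bytes) / 1024**3

mha = kv_cache_gib(num_kv_heads=64)  # MHA: as many KV heads as query heads
gqa = kv_cache_gib(num_kv_heads=8)   # GQA: 8 KV heads shared by 64 query heads
print(f"MHA: {mha:.1f} GB  GQA: {gqa:.1f} GB  ({mha / gqa:.0f}x smaller)")
```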

GQA makes long-context inference possible on current hardware.

KV Cache Quantization

Another way to reduce KV cache size: use lower precision:

Precision      Bytes per Element   KV Cache at 128K (70B)
----------------------------------------------------------
FP32                           4                    78 GB
FP16 / BF16                    2                    39 GB
INT8                           1                  19.5 GB
INT4                         0.5                  9.75 GB

def compare_quantization(model_config, seq_length):
    """Compare KV cache size at different quantization levels."""
    precisions = {
        "FP32": 4,
        "FP16": 2,
        "INT8": 1,
        "INT4": 0.5,
    }

    print(f"\nKV Cache at {seq_length:,} tokens:")
    print(f"{'Precision':<10} {'Memory (GB)':>12} {'Savings':>10}")
    print("-" * 35)

    base_mem = None
    for name, p_bytes in precisions.items():
        mem = kv_cache_memory(
            **model_config,
            seq_length=seq_length,
            precision_bytes=p_bytes,
        )
        if base_mem is None:
            base_mem = mem["GB"]
        savings = (1 - mem["GB"] / base_mem) * 100
        print(f"{name:<10} {mem['GB']:>12.2f} {savings:>9.1f}%")

compare_quantization(
    {"num_layers": 80, "num_kv_heads": 8, "head_dim": 128},
    seq_length=128_000,
)

INT8 quantization halves the KV cache with minimal quality loss. INT4 is more aggressive but can degrade model quality.

Paged Attention: vLLM’s Innovation

Traditional KV caches pre-allocate contiguous memory for the maximum sequence length. If you allocate for 128K tokens but only use 10K, 92% of memory is wasted.

Paged Attention (from the vLLM project) borrows ideas from operating system virtual memory: allocate memory in small pages (blocks) and map them via a page table.

# Conceptual paged attention
class PagedKVCache:
    def __init__(self, block_size=16, max_blocks=1000):
        self.block_size = block_size
        self.max_blocks = max_blocks
        self.blocks = {}  # block_id -> tensor
        self.page_table = {}  # sequence_id -> [block_ids]

    def allocate(self, seq_id, num_tokens):
        """Allocate only as many blocks as needed (rounding up)."""
        num_blocks = (num_tokens + self.block_size - 1) // self.block_size
        block_ids = [self._get_free_block() for _ in range(num_blocks)]
        self.page_table[seq_id] = block_ids
        return block_ids

    def _get_free_block(self):
        if len(self.blocks) >= self.max_blocks:
            raise MemoryError("out of KV cache blocks")
        block_id = len(self.blocks)
        self.blocks[block_id] = None  # will hold this block's K, V tensors
        return block_id
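For example, with a 16-token block size, a 10K-token sequence needs only 625 blocks, while a contiguous allocation sized for 128K would reserve 8,000 blocks' worth. A quick check of the 92% figure from earlier:

```python
block_size = 16
used = (10_000 + block_size - 1) // block_size  # blocks actually needed
reserved = 128_000 // block_size                # blocks a 128K pre-allocation pins
print(used, reserved, f"{1 - used / reserved:.1%} saved")  # 625 8000 92.2% saved
```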

Benefits:

  - Near-zero memory waste: only the last block of each sequence is partially filled
  - Sharing: parallel sampling and beam search can share blocks for a common prefix
  - Higher throughput: less wasted memory means more sequences fit in each batch

The Takeaway

The KV cache is the hidden cost of long context windows. It scales linearly with sequence length and can quickly dwarf the model weights themselves. Every token you add to the context costs real GPU memory — and at $2–3 per GPU-hour for H100s, that translates directly to dollars.

Understanding KV cache math lets you:

  1. Right-size your GPU fleet for your context requirements
  2. Choose the right optimization (GQA, quantization, paging)
  3. Make informed decisions about context length vs. cost

ByteBell helps engineering teams solve exactly this problem. Instead of stuffing everything into the context window, ByteBell’s Smart Context Refresh retrieves only what matters — keeping your AI sharp, fast, and accurate. Learn more at bytebell.ai