KV Cache Memory Math: Calculating Exactly How Much VRAM You Need

The exact formula for KV cache memory and worked examples for every major model architecture. Calculate your GPU requirements precisely.

The Filing Cabinet Analogy

Imagine you’re a librarian processing a line of visitors. For each visitor, you write two index cards — one describing what they’re looking for (Key) and one with their contact info (Value). You need to keep every card from every visitor because any future visitor might need to reference a past one.

Your filing cabinet has limited space. By visitor 1,000, you’re running out of drawers. By visitor 100,000, you need an entire warehouse.

This is the KV cache problem — and this post gives you the exact formulas to calculate the size of that warehouse.

The Master Formula

M_{KV} = 2 \times L \times h_{kv} \times d_h \times s \times b \times p

| Symbol | Meaning | Unit |
|--------|---------|------|
| 2 | Keys + Values | - |
| L | Number of layers | count |
| h_{kv} | Number of KV heads | count |
| d_h | Head dimension | elements |
| s | Sequence length | tokens |
| b | Batch size | count |
| p | Bytes per element | bytes |

Simplified Form

Since h_{kv} \times d_h = d_{kv} (the total KV dimension per layer):

M_{KV} = 2 \times L \times d_{kv} \times s \times b \times p

Bytes per token (useful for quick calculations):

\beta = 2 \times L \times d_{kv} \times p \quad \text{(bytes per token in KV cache)}

M_{KV} = \beta \times s \times b
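As a quick sanity check, plugging Llama 3.1 70B's published dimensions (L = 80 layers, h_kv = 8 KV heads, d_h = 128, FP16) into the β formula:

```python
# Bytes-per-token (beta) for Llama 3.1 70B at FP16
L, h_kv, d_h, p = 80, 8, 128, 2
beta = 2 * L * h_kv * d_h * p                 # Keys + Values, all layers, per token
print(beta)                                    # 327680 bytes ~= 328 KB per token
print(round(beta * 128_000 / 1024**3, 1))      # 39.1 GiB of KV cache at 128K tokens
```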

Model Architecture Reference

# Complete model architecture specifications
MODEL_CONFIGS = {
    "Llama-3.1-8B": {
        "layers": 32,
        "q_heads": 32,
        "kv_heads": 8,      # GQA: 4 groups
        "head_dim": 128,
        "model_dim": 4096,
        "params": "8B",
        "weight_memory_fp16_gb": 16,
    },
    "Llama-3.1-70B": {
        "layers": 80,
        "q_heads": 64,
        "kv_heads": 8,      # GQA: 8 groups
        "head_dim": 128,
        "model_dim": 8192,
        "params": "70B",
        "weight_memory_fp16_gb": 140,
    },
    "Llama-3.1-405B": {
        "layers": 126,
        "q_heads": 128,
        "kv_heads": 8,      # GQA: 16 groups
        "head_dim": 128,
        "model_dim": 16384,
        "params": "405B",
        "weight_memory_fp16_gb": 810,
    },
    "Mistral-7B": {
        "layers": 32,
        "q_heads": 32,
        "kv_heads": 8,      # GQA
        "head_dim": 128,
        "model_dim": 4096,
        "params": "7B",
        "weight_memory_fp16_gb": 14,
    },
    "Mixtral-8x7B": {
        "layers": 32,
        "q_heads": 32,
        "kv_heads": 8,
        "head_dim": 128,
        "model_dim": 4096,
        "params": "47B (active: 13B)",
        "weight_memory_fp16_gb": 94,
    },
    "Qwen-2.5-72B": {
        "layers": 80,
        "q_heads": 64,
        "kv_heads": 8,
        "head_dim": 128,
        "model_dim": 8192,
        "params": "72B",
        "weight_memory_fp16_gb": 144,
    },
}


def kv_cache_memory(config: dict, seq_length: int, batch_size: int = 1,
                     precision: str = "fp16") -> dict:
    """Calculate KV cache memory for a given model configuration."""
    p_bytes = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1, "int4": 0.5}[precision]

    memory_bytes = (
        2 * config["layers"] * config["kv_heads"] * config["head_dim"]
        * seq_length * batch_size * p_bytes
    )

    bytes_per_token = 2 * config["layers"] * config["kv_heads"] * config["head_dim"] * p_bytes

    return {
        "total_gb": memory_bytes / (1024**3),
        "bytes_per_token": bytes_per_token,
        "kb_per_token": bytes_per_token / 1024,
    }


# Generate comprehensive table
print(f"{'Model':<22} {'Seq Len':>10} {'KV Cache':>10} {'Weights':>10} {'Total':>10} {'KV/Weight':>10}")
print("=" * 75)

for name, config in MODEL_CONFIGS.items():
    for seq_len in [4_096, 32_768, 128_000, 1_000_000]:
        mem = kv_cache_memory(config, seq_len)
        total = mem["total_gb"] + config["weight_memory_fp16_gb"]
        ratio = mem["total_gb"] / config["weight_memory_fp16_gb"] * 100
        print(f"{name:<22} {seq_len:>10,} {mem['total_gb']:>9.1f}G "
              f"{config['weight_memory_fp16_gb']:>9}G {total:>9.1f}G {ratio:>9.1f}%")
    print("-" * 75)

GQA Savings Calculation

Standard Multi-Head Attention (MHA): Every query head has its own KV head.

M_{KV}^{\text{MHA}} = 2 \times L \times h_q \times d_h \times s \times p

Grouped-Query Attention (GQA): Multiple query heads share one KV head.

M_{KV}^{\text{GQA}} = 2 \times L \times h_{kv} \times d_h \times s \times p

Savings factor:

\text{Reduction} = \frac{h_q}{h_{kv}}

For Llama 3.1 70B (h_q = 64, h_{kv} = 8):

\text{Reduction} = \frac{64}{8} = 8\times

def gqa_savings_analysis(config: dict, seq_length: int = 128_000):
    """Compare MHA vs GQA KV cache memory."""

    # MHA: kv_heads = q_heads
    mha_mem = kv_cache_memory(
        {**config, "kv_heads": config["q_heads"]},
        seq_length
    )

    # GQA: actual kv_heads
    gqa_mem = kv_cache_memory(config, seq_length)

    # MQA: kv_heads = 1
    mqa_mem = kv_cache_memory(
        {**config, "kv_heads": 1},
        seq_length
    )

    reduction_gqa = mha_mem["total_gb"] / gqa_mem["total_gb"]
    reduction_mqa = mha_mem["total_gb"] / mqa_mem["total_gb"]

    print(f"KV Cache at {seq_length:,} tokens:")
    print(f"  MHA ({config['q_heads']} KV heads):  {mha_mem['total_gb']:.1f} GB")
    print(f"  GQA ({config['kv_heads']} KV heads):   {gqa_mem['total_gb']:.1f} GB  ({reduction_gqa:.0f}× reduction)")
    print(f"  MQA (1 KV head):     {mqa_mem['total_gb']:.1f} GB  ({reduction_mqa:.0f}× reduction)")

gqa_savings_analysis(MODEL_CONFIGS["Llama-3.1-70B"])

Output:

KV Cache at 128,000 tokens:
  MHA (64 KV heads):  312.5 GB
  GQA (8 KV heads):    39.1 GB  (8× reduction)
  MQA (1 KV head):      4.9 GB  (64× reduction)

Without GQA, running Llama 70B at 128K context would need 312 GB just for KV cache — more than any single GPU provides, and that's before the 140 GB of weights.

Quantization Impact

def quantization_comparison(config: dict, seq_length: int):
    """Compare KV cache at different quantization levels."""
    precisions = [
        ("FP32", "fp32"),
        ("FP16/BF16", "fp16"),
        ("INT8", "int8"),
        ("INT4", "int4"),
    ]

    print(f"\n{config.get('name', 'Model')} at {seq_length:,} tokens:")
    print(f"{'Precision':<12} {'KV Cache':>10} {'KB/Token':>10} {'Quality':>15}")
    print("-" * 50)

    for name, prec in precisions:
        mem = kv_cache_memory(config, seq_length, precision=prec)
        quality = {
            "fp32": "Baseline",
            "fp16": "~Same",
            "int8": "~0.5% drop",
            "int4": "~2-5% drop",
        }[prec]
        print(f"{name:<12} {mem['total_gb']:>9.1f}G {mem['kb_per_token']:>9.1f} {quality:>15}")

config = {**MODEL_CONFIGS["Llama-3.1-70B"], "name": "Llama-3.1-70B"}  # copy, don't mutate the shared dict
quantization_comparison(config, 128_000)
quantization_comparison(config, 1_000_000)
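For a sense of scale, the master formula gives these totals for Llama 3.1 70B at 128K tokens — a standalone recomputation per precision, not the function's formatted output:

```python
# KV cache for Llama 3.1 70B (L=80, h_kv=8, d_h=128) at 128K tokens,
# recomputed per precision with the master formula:
# ~78.1, ~39.1, ~19.5, and ~9.8 GiB respectively
for name, p in [("FP32", 4), ("FP16", 2), ("INT8", 1), ("INT4", 0.5)]:
    gib = 2 * 80 * 8 * 128 * 128_000 * p / 1024**3
    print(f"{name}: {gib:.1f} GiB")
```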

Maximum Context Length per GPU

Given a GPU with M_{GPU} memory, the maximum context length is:

s_{\max} = \frac{M_{GPU} - M_{\text{weights}} - M_{\text{overhead}}}{2 \times L \times h_{kv} \times d_h \times b \times p}

Where M_{\text{overhead}} covers activations, workspace, and OS overhead (typically 2-4 GB).

def max_context_length(
    config: dict,
    gpu_memory_gb: float,
    overhead_gb: float = 3.0,
    batch_size: int = 1,
    precision: str = "fp16"
) -> int:
    """Calculate maximum context length for a given GPU."""
    p_bytes = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1, "int4": 0.5}[precision]

    # Note: weights are assumed to stay at FP16; `precision` applies only to the KV cache
    available = (gpu_memory_gb - config["weight_memory_fp16_gb"] - overhead_gb) * (1024**3)

    if available <= 0:
        return 0  # Model doesn't fit

    bytes_per_token = 2 * config["layers"] * config["kv_heads"] * config["head_dim"] * p_bytes * batch_size

    return int(available / bytes_per_token)


# GPU configurations
gpus = [
    ("A100 40GB", 40),
    ("A100 80GB", 80),
    ("H100 80GB", 80),
    ("8× H100", 640),
    ("8× H200", 1120),
]

print(f"{'Model':<20} {'GPU':>12} {'Max Context':>12} {'Pages':>8}")
print("=" * 55)

for model_name, config in [("Llama-3.1-8B", MODEL_CONFIGS["Llama-3.1-8B"]),
                             ("Llama-3.1-70B", MODEL_CONFIGS["Llama-3.1-70B"])]:
    for gpu_name, gpu_mem in gpus:
        max_ctx = max_context_length(config, gpu_mem)
        pages = max_ctx * 0.75 / 250 if max_ctx > 0 else 0  # ~0.75 words/token, ~250 words/page
        ctx_str = f"{max_ctx:,}" if max_ctx > 0 else "N/A"
        print(f"{model_name:<20} {gpu_name:>12} {ctx_str:>12} {pages:>7,.0f}")
    print("-" * 55)

Bytes Per Token Reference

A handy reference for quick mental math — how many bytes each token costs in the KV cache:

| Model | FP16 (per token) | INT8 (per token) |
|-------|------------------|------------------|
| Llama 3.1 8B | 131 KB | 65 KB |
| Llama 3.1 70B | 328 KB | 164 KB |
| Llama 3.1 405B | 516 KB | 258 KB |

Quick calculation: For Llama 70B at FP16, every 3 tokens costs roughly 1 MB of GPU memory in KV cache.

At 1M tokens: 1,000,000 / 3 ≈ 333,000 MB ≈ 333 GB (the exact figure is 327,680 bytes/token × 1M ≈ 328 GB, or about 305 GiB).
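That shortcut can be checked with exact arithmetic, using Llama 3.1 70B's β of 327,680 bytes/token at FP16 (the rule of thumb slightly overstates the per-token cost):

```python
# Exact version of the "3 tokens per MB" rule for Llama 3.1 70B at FP16
beta = 2 * 80 * 8 * 128 * 2                   # 327,680 bytes/token
print(1024**2 / beta)                          # 3.2 tokens per MiB
print(round(beta * 1_000_000 / 1024**3))       # 305 GiB at 1M tokens
```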

The Takeaway

KV cache memory is the most predictable and most impactful cost in LLM deployment. With the formulas above, you can calculate the exact memory requirements for any model, any context length, any precision, and any GPU configuration — no guesswork needed.

The key insight: every token you add to the context costs real GPU memory. At scale, this directly translates to hardware costs and latency. Smart context management — including only what’s needed — saves both.


ByteBell helps engineering teams solve exactly this problem. Instead of stuffing everything into the context window, ByteBell’s Smart Context Refresh retrieves only what matters — keeping your AI sharp, fast, and accurate. Learn more at bytebell.ai