Multi-Query Attention and Grouped-Query Attention: Reducing KV Cache by 8× at the Architecture Level

Standard multi-head attention uses separate K and V for each head. MQA and GQA share them — reducing KV cache dramatically with minimal quality loss.


The Shared Textbook Analogy

In a class of 64 students (attention heads), imagine each student needs their own copy of the textbook (KV pair). That’s 64 textbooks. Expensive!

Multi-Query Attention (MQA): All 64 students share ONE textbook. Cheap, but every student works from exactly the same pages, so no one gets material tailored to them.

Grouped-Query Attention (GQA): Groups of 8 students share one textbook. 8 textbooks total: much cheaper than 64, and each group still gets its own material.

This is exactly how modern LLMs reduce KV cache memory.

Standard Multi-Head Attention (MHA)

In standard MHA, each head $i$ has its own Q, K, V projections:

$$Q^{(i)} = X W_Q^{(i)}, \quad K^{(i)} = X W_K^{(i)}, \quad V^{(i)} = X W_V^{(i)}$$

$$\text{head}_i = \text{Attention}(Q^{(i)}, K^{(i)}, V^{(i)})$$

KV cache per layer: $2 \times h \times d_h \times s$ elements.

For a model with $h = 64$ heads, $d_h = 128$, at $s = 128{,}000$ tokens in FP16 (2 bytes per element):

$$M_{\text{KV}}^{\text{MHA}} = 2 \times 64 \times 128 \times 128{,}000 \times 2 = 4.19\ \text{GB per layer}$$

With 80 layers: 335 GB just for KV cache.
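As a sanity check, a few lines of Python reproduce these figures (assuming FP16, i.e. 2 bytes per element, and decimal GB):

```python
# Sanity check on the MHA KV-cache arithmetic above.
# Assumes FP16 (2 bytes per element); GB here is decimal (1e9 bytes).
h, d_h, s, n_layers = 64, 128, 128_000, 80
bytes_per_elem = 2

per_layer = 2 * h * d_h * s * bytes_per_elem   # factor of 2: one K and one V
total = per_layer * n_layers

print(f"Per layer: {per_layer / 1e9:.2f} GB")          # 4.19 GB
print(f"All {n_layers} layers: {total / 1e9:.1f} GB")  # 335.5 GB
```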

Multi-Query Attention (MQA)

MQA (Shazeer, 2019): All heads share a single K and V:

$$Q^{(i)} = X W_Q^{(i)} \quad \text{(separate per head)}$$

$$K = X W_K, \quad V = X W_V \quad \text{(shared across all heads)}$$

Each head still has its own query, so they “ask different questions.” But they all look at the same keys and values.

$$M_{\text{KV}}^{\text{MQA}} = 2 \times 1 \times d_h \times s$$

Reduction factor: $h = 64\times$

The Quality Tradeoff

MQA reduces KV-cache memory by 64×, but every head reads the same K and V, which limits the diversity of information the heads can retrieve. Empirical results show roughly 1–3% quality degradation: small, but not zero.

Grouped-Query Attention (GQA)

GQA (Ainslie et al., 2023): A middle ground. Heads are divided into $g$ groups, each sharing one KV pair:

$$\text{Group } j: \quad K^{(j)} = X W_K^{(j)}, \quad V^{(j)} = X W_V^{(j)}$$

$$\text{Heads in group } j: \quad Q^{(i)} = X W_Q^{(i)} \quad \text{for } i \in \text{group}_j$$

With $h = 64$ heads and $g = 8$ groups, each group has 8 heads sharing one KV pair.

$$M_{\text{KV}}^{\text{GQA}} = 2 \times g \times d_h \times s$$

Reduction factor: $h/g = 64/8 = 8\times$
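To make the grouping concrete: under this scheme, query head $i$ simply reads the K and V of group $\lfloor i / (h/g) \rfloor$. A tiny illustrative sketch (not from any particular codebase):

```python
# Illustrative mapping of query heads to shared KV heads under GQA.
h, g = 64, 8                  # 64 query heads, 8 KV groups
heads_per_group = h // g      # 8 query heads share each KV pair

# Query head i attends using the K/V of group i // heads_per_group.
kv_head_of = [i // heads_per_group for i in range(h)]
print(kv_head_of[:10])        # [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
print(max(kv_head_of) + 1)    # 8 distinct KV heads in total
```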

The Sweet Spot

| Method | KV Heads | KV Cache (per layer) | Quality | Reduction |
|--------|----------|----------------------|---------|-----------|
| MHA    | 64       | $2 \times 64 \times d_h \times s$ | Baseline | 1× |
| GQA-8  | 8        | $2 \times 8 \times d_h \times s$  | ~99%     | 8× |
| GQA-4  | 4        | $2 \times 4 \times d_h \times s$  | ~98%     | 16× |
| GQA-2  | 2        | $2 \times 2 \times d_h \times s$  | ~97%     | 32× |
| MQA    | 1        | $2 \times 1 \times d_h \times s$  | ~97%     | 64× |
The table above can be reproduced with a short script:

def compare_attention_variants(
    n_query_heads: int = 64,
    head_dim: int = 128,
    seq_length: int = 128_000,
    n_layers: int = 80,
    precision_bytes: int = 2,
):
    """Compare MHA, GQA, and MQA KV cache requirements."""

    variants = [
        ("MHA", n_query_heads),
        ("GQA-16", 16),
        ("GQA-8", 8),
        ("GQA-4", 4),
        ("GQA-2", 2),
        ("MQA", 1),
    ]

    print(f"Config: {n_query_heads} Q-heads, d_h={head_dim}, "
          f"seq={seq_length:,}, {n_layers} layers, FP16")
    print()
    print(f"{'Variant':<10} {'KV Heads':>10} {'KV/Layer':>10} "
          f"{'Total KV':>10} {'Reduction':>10}")
    print("=" * 55)

    base_mem = None
    for name, kv_heads in variants:
        per_layer = 2 * kv_heads * head_dim * seq_length * precision_bytes
        total = per_layer * n_layers
        total_gb = total / (1024**3)

        if base_mem is None:
            base_mem = total_gb
        reduction = base_mem / total_gb

        print(f"{name:<10} {kv_heads:>10} {per_layer/(1024**3):>9.1f}G "
              f"{total_gb:>9.1f}G {reduction:>10.0f}×")

compare_attention_variants()

PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Standard Multi-Head Attention (MHA)."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)  # Full size
        self.W_V = nn.Linear(d_model, d_model)  # Full size
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, N, D = x.shape
        Q = self.W_Q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)

        attn = F.scaled_dot_product_attention(Q, K, V)
        return self.W_O(attn.transpose(1, 2).reshape(B, N, D))


class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention (GQA)."""

    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_groups = n_heads // n_kv_heads  # query heads per KV head
        self.d_k = d_model // n_heads

        self.W_Q = nn.Linear(d_model, n_heads * self.d_k)
        self.W_K = nn.Linear(d_model, n_kv_heads * self.d_k)  # Smaller!
        self.W_V = nn.Linear(d_model, n_kv_heads * self.d_k)  # Smaller!
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, N, D = x.shape

        Q = self.W_Q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(x).view(B, N, self.n_kv_heads, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(B, N, self.n_kv_heads, self.d_k).transpose(1, 2)

        # Expand KV heads to match Q heads by repeating
        K = K.repeat_interleave(self.n_groups, dim=1)  # (B, n_heads, N, d_k)
        V = V.repeat_interleave(self.n_groups, dim=1)

        attn = F.scaled_dot_product_attention(Q, K, V)
        return self.W_O(attn.transpose(1, 2).reshape(B, N, D))


class MultiQueryAttention(nn.Module):
    """Multi-Query Attention (MQA) — GQA with 1 KV head."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_Q = nn.Linear(d_model, n_heads * self.d_k)
        self.W_K = nn.Linear(d_model, self.d_k)  # Single head!
        self.W_V = nn.Linear(d_model, self.d_k)  # Single head!
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, N, D = x.shape

        Q = self.W_Q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(x).view(B, N, 1, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(B, N, 1, self.d_k).transpose(1, 2)

        # Broadcast single KV head to all Q heads
        K = K.expand(-1, self.n_heads, -1, -1)
        V = V.expand(-1, self.n_heads, -1, -1)

        attn = F.scaled_dot_product_attention(Q, K, V)
        return self.W_O(attn.transpose(1, 2).reshape(B, N, D))


# Parameter count comparison
d_model = 8192
n_heads = 64

mha = MultiHeadAttention(d_model, n_heads)
gqa = GroupedQueryAttention(d_model, n_heads, n_kv_heads=8)
mqa = MultiQueryAttention(d_model, n_heads)

for name, module in [("MHA", mha), ("GQA-8", gqa), ("MQA", mqa)]:
    kv_params = sum(p.numel() for n, p in module.named_parameters()
                    if 'W_K' in n or 'W_V' in n)
    total_params = sum(p.numel() for p in module.parameters())
    print(f"{name:<8} KV params: {kv_params:>12,}  Total: {total_params:>12,}")

Why Llama 3 Chose GQA-8

Llama 3 (70B) uses 64 query heads and 8 KV heads ($g = 8$). The reasoning:

  1. 8× KV cache reduction makes 128K context feasible on 8×H100
  2. Minimal quality loss (less than 0.5% on benchmarks vs MHA)
  3. Good throughput - the K/V projections and cache traffic are 8× cheaper
  4. Proven at scale - Google’s PaLM line validated shared-KV attention (as MQA) in production, and Ainslie et al. showed GQA recovers most of MHA’s quality

The math works out: at 128K context, GQA-8 needs ~42 GB of KV cache vs ~335 GB for MHA. That’s the difference between “fits alongside the weights on 8 GPUs” and “impossible.”
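Those closing figures can be checked directly. The sketch below assumes a Llama-3-70B-style configuration (80 layers, $d_h = 128$, FP16) and reports decimal GB:

```python
# Check the GQA-8 vs MHA KV-cache comparison for a 70B-class config.
# Assumes 80 layers, d_h = 128, 128K context, FP16 (2 bytes); decimal GB.
n_layers, d_h, s, bytes_per_elem = 80, 128, 128_000, 2

def kv_cache_gb(kv_heads: int) -> float:
    """Total KV cache across all layers, in decimal GB."""
    return 2 * kv_heads * d_h * s * bytes_per_elem * n_layers / 1e9

print(f"MHA   (64 KV heads): {kv_cache_gb(64):.1f} GB")  # 335.5 GB
print(f"GQA-8 ( 8 KV heads): {kv_cache_gb(8):.1f} GB")   # 41.9 GB
```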


ByteBell helps engineering teams solve exactly this problem. Instead of stuffing everything into the context window, ByteBell’s Smart Context Refresh retrieves only what matters — keeping your AI sharp, fast, and accurate. Learn more at bytebell.ai