Standard multi-head attention uses separate K and V for each head. MQA and GQA share them — reducing KV cache dramatically with minimal quality loss.
In a class of 64 students (attention heads), imagine each student needs their own copy of the textbook (KV pair). That’s 64 textbooks. Expensive!
Multi-Query Attention (MQA): All 64 students share ONE textbook. Cheap, but students are often waiting for their turn.
Grouped-Query Attention (GQA): Groups of 8 students share one textbook. 8 textbooks total — much cheaper than 64, and the wait is manageable.
This is exactly how modern LLMs reduce KV cache memory.
In standard MHA, each head has its own Q, K, V projections: $Q_i = XW_Q^i$, $K_i = XW_K^i$, $V_i = XW_V^i$ for $i = 1, \dots, n_h$.
KV cache per layer: $2 \times n_h \times d_h \times L$ elements (a K and a V vector for every head at every token position $L$).
For a model with $n_h = 64$ heads, $d_h = 128$, at $L = 128{,}000$ tokens in FP16: $2 \times 64 \times 128 \times 128{,}000 \times 2$ bytes ≈ 4.2 GB per layer.
With 80 layers: 335 GB just for KV cache.
MQA (Shazeer, 2019): All heads share a single K and V: $K = XW_K$, $V = XW_V$ (one head's worth each), while each head keeps its own $Q_i = XW_Q^i$.
Each head still has its own query, so they “ask different questions.” But they all look at the same keys and values.
Reduction factor: the KV cache shrinks by $n_h = 64\times$.
MQA reduces memory by 64× but each head sees the same KV, limiting the diversity of information each head can retrieve. Empirical results show 1-3% quality degradation — small but not zero.
GQA (Ainslie et al., 2023): A middle ground. Heads are divided into $g$ groups, each group sharing one KV pair.
With $n_h = 64$ heads and $g = 8$ groups: each group has 8 heads sharing one KV.
Reduction factor: $n_h / g = 8\times$.
| Method | KV Heads | KV Cache (per layer) | Quality | Reduction |
|---|---|---|---|---|
| MHA | 64 | $2 \cdot 64 \cdot d_h L$ | Baseline | 1× |
| GQA-8 | 8 | $2 \cdot 8 \cdot d_h L$ | ~99% | 8× |
| GQA-4 | 4 | $2 \cdot 4 \cdot d_h L$ | ~98% | 16× |
| GQA-2 | 2 | $2 \cdot 2 \cdot d_h L$ | ~97% | 32× |
| MQA | 1 | $2 \cdot d_h L$ | ~97% | 64× |
import numpy as np
def compare_attention_variants(
    n_query_heads: int = 64,
    head_dim: int = 128,
    seq_length: int = 128_000,
    n_layers: int = 80,
    precision_bytes: int = 2,
):
    """Compare MHA, GQA, and MQA KV cache requirements.

    Prints a table of per-layer and total KV cache sizes for each variant.

    Args:
        n_query_heads: number of query heads (MHA uses this many KV heads too).
        head_dim: dimension of each attention head.
        seq_length: context length in tokens.
        n_layers: number of transformer layers.
        precision_bytes: bytes per cached element (2 for FP16/BF16).

    Returns:
        List of ``(name, kv_heads, total_gib, reduction)`` tuples, one per
        variant, in the order printed (MHA first, so ``reduction`` is
        relative to it).
    """
    variants = [
        ("MHA", n_query_heads),
        ("GQA-16", 16),
        ("GQA-8", 8),
        ("GQA-4", 4),
        ("GQA-2", 2),
        ("MQA", 1),
    ]
    # Derive the precision label from the parameter instead of hardcoding
    # "FP16" — the old label was wrong for any other precision_bytes.
    print(f"Config: {n_query_heads} Q-heads, d_h={head_dim}, "
          f"seq={seq_length:,}, {n_layers} layers, "
          f"{precision_bytes * 8}-bit")
    print()
    print(f"{'Variant':<10} {'KV Heads':>10} {'KV/Layer':>10} "
          f"{'Total KV':>10} {'Reduction':>10}")
    print("=" * 55)
    results = []
    base_mem = None
    for name, kv_heads in variants:
        # K and V each store kv_heads * head_dim values per token.
        per_layer = 2 * kv_heads * head_dim * seq_length * precision_bytes
        total_gb = per_layer * n_layers / (1024**3)
        if base_mem is None:
            base_mem = total_gb  # first row (MHA) is the baseline
        reduction = base_mem / total_gb
        results.append((name, kv_heads, total_gb, reduction))
        print(f"{name:<10} {kv_heads:>10} {per_layer * n_layers / n_layers / (1024**3):>9.1f}G "
              f"{total_gb:>9.1f}G {reduction:>10.0f}×")
    return results
compare_attention_variants()import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    """Standard Multi-Head Attention: every head owns its own K and V.

    All four projections are full-width (d_model -> d_model), so the KV
    cache grows with the number of heads.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)  # full width: one K per head
        self.W_V = nn.Linear(d_model, d_model)  # full width: one V per head
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch, seq, dim = x.shape

        def split_heads(t):
            # (B, N, D) -> (B, n_heads, N, d_k)
            return t.view(batch, seq, self.n_heads, self.d_k).transpose(1, 2)

        q, k, v = (split_heads(proj(x)) for proj in (self.W_Q, self.W_K, self.W_V))
        ctx = F.scaled_dot_product_attention(q, k, v)
        ctx = ctx.transpose(1, 2).reshape(batch, seq, dim)
        return self.W_O(ctx)
class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention: n_heads query heads share n_kv_heads K/V heads.

    K and V projections are only n_kv_heads * d_k wide, shrinking the KV
    cache by a factor of n_heads / n_kv_heads relative to MHA.
    """

    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_groups = n_heads // n_kv_heads  # query heads per KV head
        self.d_k = d_model // n_heads
        self.W_Q = nn.Linear(d_model, n_heads * self.d_k)
        self.W_K = nn.Linear(d_model, n_kv_heads * self.d_k)  # reduced width
        self.W_V = nn.Linear(d_model, n_kv_heads * self.d_k)  # reduced width
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch, seq, dim = x.shape
        q = self.W_Q(x).view(batch, seq, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_K(x).view(batch, seq, self.n_kv_heads, self.d_k).transpose(1, 2)
        v = self.W_V(x).view(batch, seq, self.n_kv_heads, self.d_k).transpose(1, 2)
        # Duplicate each KV head n_groups times so shapes line up with q.
        k, v = (t.repeat_interleave(self.n_groups, dim=1) for t in (k, v))
        ctx = F.scaled_dot_product_attention(q, k, v)
        return self.W_O(ctx.transpose(1, 2).reshape(batch, seq, dim))
class MultiQueryAttention(nn.Module):
    """Multi-Query Attention: one shared K/V head serves every query head.

    Equivalent to GQA with n_kv_heads=1; the K and V projections are only
    d_k wide.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_Q = nn.Linear(d_model, n_heads * self.d_k)
        self.W_K = nn.Linear(d_model, self.d_k)  # one head's worth
        self.W_V = nn.Linear(d_model, self.d_k)  # one head's worth
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch, seq, dim = x.shape
        q = self.W_Q(x).view(batch, seq, self.n_heads, self.d_k).transpose(1, 2)
        kv_shape = (batch, seq, 1, self.d_k)
        k = self.W_K(x).view(kv_shape).transpose(1, 2)
        v = self.W_V(x).view(kv_shape).transpose(1, 2)
        # expand() broadcasts the single KV head across all query heads
        # without allocating new memory.
        k, v = (t.expand(-1, self.n_heads, -1, -1) for t in (k, v))
        ctx = F.scaled_dot_product_attention(q, k, v)
        return self.W_O(ctx.transpose(1, 2).reshape(batch, seq, dim))
# Parameter count comparison
d_model = 8192
n_heads = 64
mha = MultiHeadAttention(d_model, n_heads)
gqa = GroupedQueryAttention(d_model, n_heads, n_kv_heads=8)
mqa = MultiQueryAttention(d_model, n_heads)
for name, module in [("MHA", mha), ("GQA-8", gqa), ("MQA", mqa)]:
kv_params = sum(p.numel() for n, p in module.named_parameters()
if 'W_K' in n or 'W_V' in n)
total_params = sum(p.numel() for p in module.parameters())
print(f"{name:<8} KV params: {kv_params:>12,} Total: {total_params:>12,}")Llama 3 (70B) uses 64 query heads and 8 KV heads (). The reasoning:
The math works out perfectly: at 128K context, GQA-8 needs ~39 GB KV cache vs ~312 GB for MHA. That’s the difference between “fits on 8 GPUs” and “impossible.”
ByteBell helps engineering teams solve exactly this problem. Instead of stuffing everything into the context window, ByteBell’s Smart Context Refresh retrieves only what matters — keeping your AI sharp, fast, and accurate. Learn more at bytebell.ai