When a single GPU can't hold the KV cache, you distribute the sequence across multiple GPUs. Here's how ring attention enables million-token contexts.
Imagine you’re building cars. One worker can’t build the entire car alone — it’s too complex and takes too long. So you create an assembly line: each worker handles one part, passes it to the next.
But what if you have a car that’s too long to fit in one workstation? You split the car into sections, with each workstation working on its section simultaneously, occasionally passing parts between stations.
That’s sequence parallelism. When the context (sequence) is too long for one GPU, you split it across multiple GPUs and coordinate the work.
There are three main ways to distribute LLM work across GPUs:

**Tensor parallelism.** Split each layer's weights across GPUs. Every GPU processes the full sequence but only part of each matrix multiplication.

- Good for: large models that don't fit on one GPU
- Limitation: doesn't help with long sequences (each GPU still sees the full context)

**Pipeline parallelism.** Assign different layers to different GPUs: GPU 1 runs layers 1–20, GPU 2 runs layers 21–40, and so on.

- Good for: very deep models
- Limitation: the sequential dependency between stages creates pipeline bubbles

**Sequence parallelism.** Split the sequence across GPUs. Each GPU holds a chunk of the context.

- Good for: long contexts that exceed single-GPU memory

This is where Ring Attention comes in.
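To make the first strategy concrete before contrasting it with sequence parallelism, here is a minimal NumPy sketch of the tensor-parallel idea (all sizes are hypothetical): each "GPU" holds a column slice of the weight matrix, and concatenating the partial outputs reproduces the full matmul.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_in, d_out, n_gpus = 16, 32, 64, 4

X = rng.standard_normal((n_tokens, d_in))   # full sequence, present on every GPU
W = rng.standard_normal((d_in, d_out))      # full weight matrix

# Tensor parallelism: split W column-wise across GPUs
W_shards = np.split(W, n_gpus, axis=1)

# Each "GPU" multiplies the FULL sequence by only its shard ...
partials = [X @ shard for shard in W_shards]

# ... and an all-gather concatenates the partial outputs
Y_tp = np.concatenate(partials, axis=1)

assert np.allclose(Y_tp, X @ W)  # identical to the single-GPU result
```

Note that `X` (the whole sequence) still lives on every device — which is exactly why tensor parallelism alone does not help with long contexts.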
Self-attention requires every token to attend to every other token. If you split the sequence across GPUs, each GPU only has a portion of the keys and values. How does GPU 1 compute attention over keys that live on GPU 4?
A token on GPU 1 needs to compute dot products with tokens on GPU 4. The keys must somehow travel between GPUs.
Ring Attention (Liu et al., 2023) arranges GPUs in a ring topology. Each GPU:

1. Computes attention between its local queries and the K,V block it currently holds
2. Asynchronously sends that block to the next GPU in the ring
3. Receives a new K,V block from the previous GPU

After P − 1 passes, every GPU has attended over the entire sequence.
```python
def ring_attention(
    local_Q,    # this GPU's queries (n/P × d)
    local_K,    # this GPU's keys    (n/P × d)
    local_V,    # this GPU's values  (n/P × d)
    n_gpus,     # number of GPUs in the ring (P)
    gpu_rank,   # this GPU's rank (0 to P-1)
):
    """
    Ring Attention: each GPU computes attention over all K,V
    blocks by passing them around a ring.
    Total communication: each GPU sends/receives K,V
    (n_gpus - 1) times.
    """
    n_local = local_Q.shape[0]
    d = local_Q.shape[1]

    # Initialize output accumulator and online-softmax statistics
    O = zeros(n_local, d)               # output
    m = full(n_local, -float('inf'))    # running max
    l = zeros(n_local)                  # running sum

    # K, V block currently held by this GPU
    current_K = local_K
    current_V = local_V

    for step in range(n_gpus):
        # ========================================
        # OVERLAP: communication + computation
        # ========================================
        # Start async send/receive (NON-BLOCKING)
        if step < n_gpus - 1:
            # Send current K, V to the next GPU in the ring
            send_async(current_K, current_V, dest=(gpu_rank + 1) % n_gpus)
            # Receive K, V from the previous GPU
            next_K, next_V = recv_async(src=(gpu_rank - 1) % n_gpus)

        # COMPUTE: attention between local Q and the current K, V block.
        # This runs WHILE the communication happens.
        S = local_Q @ current_K.T / sqrt(d)     # (n_local × n_local)

        # Online softmax update (same trick as Flash Attention)
        m_new = maximum(m, S.max(dim=-1))
        correction = exp(m - m_new)
        P_blk = exp(S - m_new.unsqueeze(-1))
        l_new = correction * l + P_blk.sum(dim=-1)
        O = (correction.unsqueeze(-1) * l.unsqueeze(-1) * O
             + P_blk @ current_V) / l_new.unsqueeze(-1)
        m = m_new
        l = l_new

        # Wait for the communication to complete, then swap in the new block
        if step < n_gpus - 1:
            wait_communication()
            current_K = next_K
            current_V = next_V

    return O
```
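Because every GPU eventually attends over every K,V block, ring attention is mathematically exact, not an approximation. A single-process NumPy simulation (illustrative only — the loop over `src` stands in for the actual ring communication) can verify the online-softmax accumulation against vanilla attention:

```python
import numpy as np

def full_attention(Q, K, V):
    """Reference: standard softmax attention in one shot."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def ring_attention_sim(Q, K, V, n_gpus):
    """Simulate the ring on one process: each 'GPU' holds one chunk of
    Q, K, V and sees the other K,V chunks one ring step at a time."""
    d = Q.shape[1]
    Qs, Ks, Vs = np.split(Q, n_gpus), np.split(K, n_gpus), np.split(V, n_gpus)
    outputs = []
    for rank in range(n_gpus):
        q = Qs[rank]
        O = np.zeros_like(q)
        m = np.full(q.shape[0], -np.inf)   # running max
        l = np.zeros(q.shape[0])           # running sum
        for step in range(n_gpus):
            # At this step, rank holds the block that started at (rank - step)
            src = (rank - step) % n_gpus
            S = q @ Ks[src].T / np.sqrt(d)
            m_new = np.maximum(m, S.max(axis=-1))
            corr = np.exp(m - m_new)
            P = np.exp(S - m_new[:, None])
            l_new = corr * l + P.sum(axis=-1)
            O = (corr[:, None] * l[:, None] * O + P @ Vs[src]) / l_new[:, None]
            m, l = m_new, l_new
        outputs.append(O)
    return np.concatenate(outputs)

rng = np.random.default_rng(1)
Q = rng.standard_normal((32, 16))
K = rng.standard_normal((32, 16))
V = rng.standard_normal((32, 16))
assert np.allclose(ring_attention_sim(Q, K, V, 4), full_attention(Q, K, V))
```

The assertion passes: chunked accumulation with running max and sum reproduces the exact softmax, which is the same guarantee Flash Attention relies on.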
The magic of ring attention is that while GPU $i$ is computing attention over its current K,V block, it's simultaneously sending that block to GPU $i+1$ and receiving a new block from GPU $i-1$.
If computation takes longer than communication (which it usually does for large blocks), the communication is completely hidden:
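A back-of-envelope check of that claim, using assumed values throughout (1M-token context, a model dimension of 8192, H100-class FP16 throughput of ~990 TFLOPS, NVLink at 900 GB/s):

```python
def attention_block_flops(chunk_tokens, d_model):
    # QK^T and PV each cost 2 * chunk^2 * d FLOPs
    return 4 * chunk_tokens**2 * d_model

chunk = 1_048_576 // 8                                     # 1M tokens across 8 GPUs (assumed)
t_compute = attention_block_flops(chunk, 8192) / 990e12    # idealized H100 FP16 peak
t_comm = 2 * chunk * 8192 * 2 / 900e9                      # K+V chunk, FP16, over NVLink

print(t_compute / t_comm)   # ratio well over 100: communication fully hidden
```

With these (hypothetical) sizes, a step's compute takes over a hundred times longer than the transfer, so the ring exchange is entirely overlapped. The ratio shrinks for small chunks, which is why sequence parallelism pays off mainly at long context lengths.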
Each GPU sends its local K and V blocks around the ring:
Total sends per GPU: $P - 1$ (each GPU receives every other GPU's data exactly once).
For $P$ GPUs connected at NVLink bandwidth $B$, each ring step moves one K chunk and one V chunk of $(n/P) \times d$ elements, so the per-step communication time is roughly

$$T_{\text{comm}} \approx \frac{2 \cdot (n/P) \cdot d \cdot s}{B}$$

where $s$ is the bytes per element (2 for FP16). With 8 GPUs and NVLink at 900 GB/s, this transfer takes single-digit milliseconds for realistic chunk sizes — far less than the attention compute over the same block.
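As a worked illustration (the 1M-token context and model dimension of 8192 are assumed values, not figures from a specific deployment):

```python
def ring_step_comm_seconds(n_tokens, d_model, n_gpus,
                           bytes_per_elem=2, bandwidth=900e9):
    """Time to send one K chunk and one V chunk to the next GPU."""
    chunk_bytes = 2 * (n_tokens // n_gpus) * d_model * bytes_per_elem  # K and V
    return chunk_bytes / bandwidth

# Assumed: 1M-token context, d_model = 8192, 8 GPUs, FP16, NVLink at 900 GB/s
t = ring_step_comm_seconds(1_048_576, 8192, 8)
print(f"{t * 1e3:.2f} ms per step")   # → 4.77 ms per step
```

A few milliseconds per step is easily hidden behind the attention computation for blocks of this size.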
With ring attention, each GPU stores:

- Its local Q, K, V chunks: $3 \cdot (n/P) \cdot d$ elements
- The in-flight K, V block arriving from the ring: $2 \cdot (n/P) \cdot d$ elements
- The output accumulator and softmax statistics: $\approx (n/P) \cdot d$ elements

Total per GPU: $O(n \cdot d / P)$ — memory scales with $n/P$, not with $n$. For 8 GPUs in FP16, the sequence-dependent memory on each GPU is one eighth of what a single GPU would need for the same context. Ring attention reduces per-GPU memory by a factor of $P$.
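The memory accounting can be sketched in a few lines (the context length and model dimension are assumptions for illustration):

```python
def ring_memory_bytes(n_tokens, d_model, n_gpus, bytes_per_elem=2):
    """Per-GPU attention-state memory under ring attention (sketch):
    local Q, K, V chunks + one in-flight K,V block + output accumulator."""
    chunk = (n_tokens // n_gpus) * d_model * bytes_per_elem
    return 3 * chunk + 2 * chunk + chunk   # Q,K,V + in-flight K,V + O

# Assumed: 1M-token context, d_model = 8192, FP16
single = ring_memory_bytes(1_048_576, 8192, 1)
ring   = ring_memory_bytes(1_048_576, 8192, 8)
print(single / ring)   # → 8.0
```

Every term scales with the chunk size, so the reduction is exactly the ring size $P$.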
DeepSpeed Ulysses takes a different approach to sequence parallelism:
Instead of passing K,V around a ring, it uses all-to-all communication to redistribute the sequence: before attention, an all-to-all exchanges data so that each GPU holds the full sequence for a subset of attention heads; after attention, a second all-to-all restores the sequence split.
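The redistribution can be simulated on one process with NumPy — sizes are hypothetical, and the list comprehension stands in for the all-to-all collective:

```python
import numpy as np

n_tokens, n_heads, d_head, n_gpus = 16, 8, 4, 4
rng = np.random.default_rng(0)
X = rng.standard_normal((n_tokens, n_heads, d_head))

# Before attention, each GPU holds a SEQUENCE chunk covering all heads
seq_shards = np.split(X, n_gpus, axis=0)   # P × (n/P, heads, d_head)

# All-to-all: every GPU sends each head-group to the GPU that owns it.
# Afterwards each GPU holds the FULL sequence for n_heads/P heads,
# so attention needs no further communication.
head_shards = [
    np.concatenate([np.split(s, n_gpus, axis=1)[g] for s in seq_shards], axis=0)
    for g in range(n_gpus)
]

assert head_shards[0].shape == (n_tokens, n_heads // n_gpus, d_head)
assert np.allclose(np.concatenate(head_shards, axis=1), X)
```

This is why Ulysses needs the head count to be divisible by the GPU count — each device must own a whole number of heads.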
Ring Attention vs. DeepSpeed Ulysses:
| Property | Ring Attention | DeepSpeed Ulysses |
|---|---|---|
| Communication pattern | Point-to-point (ring) | All-to-all |
| Communication volume | $O(n \cdot d)$ per device | $O(n \cdot d / P)$ per device |
| Overlap with compute | Yes (natural) | Harder |
| Implementation complexity | Moderate | Lower |
| Best for | Very long sequences | Moderate sequences with many heads |
The maximum context length is now bounded by total GPU memory across all nodes:

$$n_{\max} \approx \frac{P \cdot \left( M_{\text{GPU}} - M_{\text{model}} / t \right)}{m_{\text{KV}}}$$

Where $t$ is the tensor parallelism degree (for model weight distribution), $M_{\text{GPU}}$ is the memory per GPU, $M_{\text{model}}$ is the total weight memory, and $m_{\text{KV}}$ is the KV-cache memory per token.
With 64 H100 GPUs (80 GB each) running Llama 70B: 64 × 80 GB ≈ 5.1 TB of total memory, most of which is free for KV cache once the weights are sharded. Theoretically, 40 million tokens of context — though attention dilution would make most of that useless.
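A back-of-envelope version of that estimate, with every number an assumption (Llama-70B-like FP16 weights, TP degree 8, and a GQA KV cache of roughly 80 layers × 2 × 8 KV heads × 128 dims × 2 bytes ≈ 320 KB/token). With these conservative choices the ceiling lands around 12M tokens; lighter KV caches (FP8, fewer KV heads) push it toward the tens of millions:

```python
def max_context_tokens(n_gpus, mem_per_gpu, model_bytes,
                       tp_degree, kv_bytes_per_token):
    """Rough ceiling: memory left after weights, divided by per-token KV cache."""
    free_per_gpu = mem_per_gpu - model_bytes / tp_degree
    return int(n_gpus * free_per_gpu / kv_bytes_per_token)

# Assumed: 64 × 80 GB GPUs, 140 GB FP16 weights, TP = 8, 320 KB KV per token
tokens = max_context_tokens(64, 80e9, 140e9, 8, 327_680)
print(f"{tokens / 1e6:.0f}M tokens")   # → 12M tokens
```

The formula ignores activations and fragmentation, so treat it as an upper bound on the upper bound.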
Today’s production systems typically combine:

- Tensor parallelism within a node, for the model weights
- Pipeline parallelism across nodes, for model depth
- Sequence parallelism — ring attention or Ulysses-style all-to-all — for the context
- Flash Attention kernels within each GPU
This stack enables the 1M+ context windows offered by frontier models — but the engineering complexity is enormous, which is why only a handful of companies can operate at this scale.
ByteBell helps engineering teams solve exactly this problem. Instead of stuffing everything into the context window, ByteBell’s Smart Context Refresh retrieves only what matters — keeping your AI sharp, fast, and accurate. Learn more at bytebell.ai