Google's Infini-Attention combines standard attention with a compressive memory that persists across segments — enabling theoretically infinite context at O(1) memory.
Imagine reading a very long book but you only have a small notebook. You can’t write down everything, so after each chapter, you write a summary in your notebook. When the next chapter references something from before, you check your notes.
Your notebook has fixed pages — it doesn’t grow with the book. But the quality of your notes determines how well you remember.
This is Infini-Attention. The model processes text in segments, maintaining a fixed-size “notebook” (compressive memory) that accumulates information across all segments. The memory size doesn’t grow with context — it stays constant regardless of how many tokens have been processed.
Infini-Attention (Munkhdalai et al., 2024) adds a compressive memory alongside standard attention:
For each segment of length $S$, the model runs standard attention locally and consults the compressive memory for everything that came before.

The compressive memory acts as an associative store — a matrix $M \in \mathbb{R}^{d_k \times d_v}$ that maps queries to values.

Retrieval: Given query $Q$, retrieve from memory:

$$A_{\text{mem}} = \frac{\sigma(Q)\,M}{\sigma(Q)\,z}$$

Where $\sigma$ is a nonlinearity (e.g., ELU+1) and $z \in \mathbb{R}^{d_k}$ is a normalization vector.
Update: After processing segment $s$, update memory:

$$M \leftarrow M + \sigma(K)^\top V, \qquad z \leftarrow z + \sum_t \sigma(K_t)$$

This is a linear attention update — each key-value pair is added to the memory via an outer product.
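The update-and-retrieve cycle can be exercised on a toy associative store (a minimal sketch with made-up values; the keys are chosen orthogonal and the ELU+1 kernel is omitted so that retrieval comes out exact):

```python
import torch

# Two key-value pairs to store (hypothetical values for illustration)
k1, v1 = torch.tensor([1., 0., 0., 0.]), torch.tensor([1., 2., 3., 4.])
k2, v2 = torch.tensor([0., 1., 0., 0.]), torch.tensor([5., 6., 7., 8.])

# Update: accumulate outer products M = Σ kᵢ vᵢᵀ and the normalizer z = Σ kᵢ
M = torch.outer(k1, v1) + torch.outer(k2, v2)
z = k1 + k2

# Retrieval: (q @ M) / (q @ z) with query q = k1
q = k1
retrieved = (q @ M) / (q @ z)
print(retrieved)  # tensor([1., 2., 3., 4.]) — v1 recovered exactly
```

With non-orthogonal keys — the realistic case — retrieved values are blurred mixtures of everything stored, which is where the lossiness discussed later comes from.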
The final output combines local attention and memory attention:

$$A = g \odot A_{\text{local}} + (1 - g) \odot A_{\text{mem}}, \qquad g = \mathrm{sigmoid}(\beta)$$

Where $\beta$ is a learned gating parameter (one per head). The model learns when to trust local context vs. long-range memory.
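The effect of the gate can be sketched with hypothetical per-head values of $\beta$ — each head blends its local and memory outputs with its own sigmoid weight:

```python
import torch

beta = torch.tensor([-4.0, 0.0, 4.0])   # hypothetical learned gates, one per head
gate = torch.sigmoid(beta).view(-1, 1)  # per-head mixing weight in (0, 1)

A_local = torch.ones(3, 1)              # stand-in local-attention outputs
A_mem = torch.zeros(3, 1)               # stand-in memory outputs

out = gate * A_local + (1 - gate) * A_mem
print(out.squeeze(-1))                  # ≈ [0.02, 0.50, 0.98]
```

A head with large negative $\beta$ leans almost entirely on memory; a large positive $\beta$ ignores it.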
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class InfiniAttention(nn.Module):
    """
    Infini-Attention: bounded memory for unbounded context.
    Memory size is O(d_k × d_v) — independent of sequence length.
    """
    def __init__(self, d_model, n_heads, segment_size=2048):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.segment_size = segment_size

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

        # Learned gating parameter (one per head)
        self.beta = nn.Parameter(torch.zeros(n_heads))

    def elu_plus_one(self, x):
        """ELU + 1 activation for the linear attention kernel."""
        return F.elu(x) + 1

    def forward(self, x):
        B, N, D = x.shape
        H, dk = self.n_heads, self.d_k

        # Project to Q, K, V
        Q = self.W_Q(x).view(B, N, H, dk).transpose(1, 2)  # (B, H, N, dk)
        K = self.W_K(x).view(B, N, H, dk).transpose(1, 2)
        V = self.W_V(x).view(B, N, H, dk).transpose(1, 2)

        # Initialize compressive memory
        M = torch.zeros(B, H, dk, dk, device=x.device)  # O(d_k²) — FIXED SIZE
        z = torch.zeros(B, H, dk, device=x.device)      # Normalization vector

        outputs = []

        # Process in segments
        for start in range(0, N, self.segment_size):
            end = min(start + self.segment_size, N)
            Q_seg = Q[:, :, start:end]  # (B, H, seg_len, dk)
            K_seg = K[:, :, start:end]
            V_seg = V[:, :, start:end]

            # 1. Local attention (causal softmax attention within the segment)
            A_local = F.scaled_dot_product_attention(
                Q_seg, K_seg, V_seg, is_causal=True
            )

            # 2. Memory-based attention (linear attention retrieval)
            Q_prime = self.elu_plus_one(Q_seg)  # (B, H, seg_len, dk)
            K_prime = self.elu_plus_one(K_seg)

            # Retrieve from memory: (B, H, seg_len, dk) @ (B, H, dk, dk)
            mem_numerator = torch.matmul(Q_prime, M)  # (B, H, seg_len, dk)
            mem_denominator = torch.matmul(
                Q_prime, z.unsqueeze(-1)
            ) + 1e-6  # (B, H, seg_len, 1)
            A_mem = mem_numerator / mem_denominator

            # 3. Gated combination of local and memory paths
            gate = torch.sigmoid(self.beta).view(1, H, 1, 1)
            output = gate * A_local + (1 - gate) * A_mem
            outputs.append(output)

            # 4. Update memory with this segment's information:
            #    M += K'^T @ V (outer-product accumulation), z += Σ K'
            M = M + torch.matmul(K_prime.transpose(-2, -1), V_seg)
            z = z + K_prime.sum(dim=-2)

        # Concatenate all segment outputs
        full_output = torch.cat(outputs, dim=2)  # (B, H, N, dk)
        full_output = full_output.transpose(1, 2).reshape(B, N, D)
        return self.W_O(full_output)
```
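A useful sanity check on the memory update: folding segments into the fixed-size state $(M, z)$ reproduces one-shot (non-causal) linear attention over the entire sequence — the O(1) state loses nothing relative to computing linear attention in one pass. A standalone sketch, separate from the class above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, d, seg = 12, 8, 4
phi = lambda x: F.elu(x) + 1  # the ELU+1 kernel

Q, K, V = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)
Qp, Kp = phi(Q), phi(K)

# One-shot linear attention over all N tokens
full = (Qp @ (Kp.T @ V)) / (Qp @ Kp.sum(0, keepdim=True).T)

# Streaming: fold each segment into the fixed-size state (M, z)
M, z = torch.zeros(d, d), torch.zeros(d)
for s in range(0, N, seg):
    M = M + Kp[s:s+seg].T @ V[s:s+seg]
    z = z + Kp[s:s+seg].sum(0)

streamed = (Qp @ M) / (Qp @ z.unsqueeze(-1))
print(torch.allclose(full, streamed, atol=1e-5))  # True
```

The approximation relative to full softmax attention comes from the kernel trick itself, not from the segment-wise accumulation.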
```python
# Memory comparison
d_model = 4096
n_heads = 32
dk = d_model // n_heads

for n_tokens in [1_000, 10_000, 100_000, 1_000_000]:
    # Standard attention KV cache (K and V, FP16)
    kv_standard = 2 * n_tokens * d_model * 2  # bytes
    # Infini-Attention state: M + z per head (FP16)
    mem_infini = n_heads * dk * dk * 2 * 2
    # Plus one segment's KV cache
    segment_kv = 2 * 2048 * d_model * 2
    total_infini = mem_infini + segment_kv
    print(f"n={n_tokens:>10,} Standard KV: {kv_standard/(1024**2):>8.1f} MB "
          f"Infini: {total_infini/(1024**2):>8.1f} MB "
          f"Savings: {kv_standard/total_infini:>6.0f}×")
```

Output:
```
n=     1,000 Standard KV:     15.6 MB Infini:     34.0 MB Savings:      0×
n=    10,000 Standard KV:    156.2 MB Infini:     34.0 MB Savings:      5×
n=   100,000 Standard KV:   1562.5 MB Infini:     34.0 MB Savings:     46×
n= 1,000,000 Standard KV:  15625.0 MB Infini:     34.0 MB Savings:    460×
```

At 1M tokens, Infini-Attention uses about 460× less memory than standard attention — and the gap keeps widening, because its footprint is constant.
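Note the first row: for short contexts, the fixed overhead (memory matrix plus one segment's KV cache) makes Infini-Attention *larger* than a plain KV cache. The crossover point can be computed directly from the same sizes (a quick sketch using the assumptions above):

```python
d_model, n_heads, segment_size = 4096, 32, 2048
dk = d_model // n_heads

bytes_per_token = 2 * d_model * 2  # K + V rows per token, FP16
fixed = n_heads * dk * dk * 2 * 2 + 2 * segment_size * d_model * 2

print(round(fixed / bytes_per_token))  # 2176 — tokens needed to break even
```

Below roughly two thousand tokens there is nothing to compress; the technique only pays off at long context.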
| Approach | Memory | Context | Quality |
|---|---|---|---|
| Standard Attention | Grows with context (KV cache) | Fixed | Best |
| Transformer-XL | Segment + cache | Extended | Good |
| Memorizing Transformer | Full (kNN lookup) | Long | Good |
| Infini-Attention | Fixed | Unbounded | Moderate |
| Linear Attention | Fixed | Unbounded | Lower |
Compressive memory is lossy. Information from early segments is compressed into a fixed-size matrix. Details are inevitably lost.
Interference. As more key-value pairs are added to $M$, older entries get “overwritten” by similar newer ones. This is called catastrophic interference.
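Interference is easy to demonstrate: store random key-value pairs in the fixed-size matrix and try to read the first one back — retrieval error grows as the store fills. A toy sketch with made-up dimensions:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 32
phi = lambda x: F.elu(x) + 1

def retrieval_error(n_pairs):
    K, V = torch.randn(n_pairs, d), torch.randn(n_pairs, d)
    Kp = phi(K)
    M = Kp.T @ V       # fixed-size store, regardless of n_pairs
    z = Kp.sum(0)
    q = Kp[0:1]        # query with the first stored key
    v_hat = (q @ M) / (q @ z.unsqueeze(-1))
    return (v_hat - V[0]).norm().item()

for n in [1, 10, 100, 1000]:
    print(n, retrieval_error(n))  # error rises as more pairs pile in
```

With one pair the value is recovered exactly; with a thousand, the retrieved vector is close to a global average and the original value is largely gone.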
Retrieval precision. Linear attention (used for memory retrieval) is less precise than softmax attention. The model may fail to retrieve specific details.
Training complexity. Training with very long sequences (to fill the memory) requires significant compute.
Despite these limitations, Infini-Attention represents an important direction: bounded memory for unbounded context is the only way to truly scale to millions of tokens without proportionally scaling hardware.
ByteBell helps engineering teams solve exactly this problem. Instead of stuffing everything into the context window, ByteBell’s Smart Context Refresh retrieves only what matters — keeping your AI sharp, fast, and accurate. Learn more at bytebell.ai