Infini-Attention and Compressive Memory: Unbounded Context with Bounded Memory

Google's Infini-Attention combines standard attention with a compressive memory that persists across segments — enabling theoretically infinite context at O(1) memory.

The Notebook Analogy

Imagine reading a very long book but you only have a small notebook. You can’t write down everything, so after each chapter, you write a summary in your notebook. When the next chapter references something from before, you check your notes.

Your notebook has fixed pages — it doesn’t grow with the book. But the quality of your notes determines how well you remember.

This is Infini-Attention. The model processes text in segments, maintaining a fixed-size “notebook” (compressive memory) that accumulates information across all segments. The memory size doesn’t grow with context — it’s $O(1)$ regardless of how many tokens have been processed.

The Architecture

Infini-Attention (Munkhdalai et al., 2024) adds a compressive memory $M$ alongside standard attention:

Per-Segment Processing

For each segment $t$ of length $s$:

  1. Standard local attention over the current segment
  2. Memory-based attention using the compressive memory $M_{t-1}$
  3. Gated combination of both signals
  4. Memory update to produce $M_t$
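These four steps can be sketched for a single head in plain NumPy. This is a simplified, non-causal sketch that assumes $\sigma$ = ELU+1 and a fixed gate of 0.5 (in the real model the gate is learned); the full PyTorch module appears later in this post:

```python
import numpy as np

def elu1(x):
    """ELU + 1 nonlinearity: strictly positive, used as the linear-attention kernel."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def infini_segment_pass(Q, K, V, segment=4, gate=0.5):
    """Process (N, d) queries/keys/values segment by segment with a fixed-size memory."""
    N, d = Q.shape
    M = np.zeros((d, d))   # compressive memory, O(d^2) regardless of N
    z = np.zeros(d)        # normalization vector
    out = np.zeros((N, d))
    for s in range(0, N, segment):
        q, k, v = Q[s:s+segment], K[s:s+segment], V[s:s+segment]
        # 1. local softmax attention over the current segment
        scores = q @ k.T / np.sqrt(d)
        w = np.exp(scores - scores.max(1, keepdims=True))
        A_local = (w / w.sum(1, keepdims=True)) @ v
        # 2. memory retrieval (linear attention against M from earlier segments)
        qp = elu1(q)
        A_mem = (qp @ M) / (qp @ z + 1e-6)[:, None]
        # 3. gated combination (fixed 0.5 here; learned per head in the paper)
        out[s:s+segment] = gate * A_local + (1 - gate) * A_mem
        # 4. memory update: accumulate key-value outer products
        kp = elu1(k)
        M += kp.T @ v
        z += kp.sum(0)
    return out

rng = np.random.default_rng(0)
X = rng.standard_normal((8, 4))
print(infini_segment_pass(X, X, X).shape)  # (8, 4)
```

Note that the memory is read before it is updated, so each segment only retrieves information from earlier segments.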

The Memory Mechanism

The compressive memory $M \in \mathbb{R}^{d_k \times d_v}$ acts as an associative store — a matrix that maps queries to values:

Retrieval: Given query $q$, retrieve from memory:

$$A_{\text{mem}} = \frac{\sigma(q) M_{t-1}}{\sigma(q) z_{t-1}}$$

Where $\sigma$ is a nonlinearity (e.g., ELU+1) and $z_{t-1} \in \mathbb{R}^{d_k}$ is a normalization vector.

Update: After processing segment $t$, update memory:

$$M_t = M_{t-1} + \sigma(K_t)^T V_t$$

$$z_t = z_{t-1} + \sum_{i} \sigma(k_{t,i})$$

This is a linear attention update — each key-value pair is added to the memory via outer product.
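The outer-product store and retrieval can be demonstrated in a few lines of NumPy. This is a toy sketch where $\sigma$ is taken as the identity (the paper uses ELU+1) and the keys are one-hot, hence orthogonal, so retrieval is exact:

```python
import numpy as np

d_k, d_v = 4, 4
M = np.zeros((d_k, d_v))   # compressive memory
z = np.zeros(d_k)          # normalization vector

# Two key-value pairs with orthogonal (one-hot) keys
k1, v1 = np.eye(d_k)[0], np.array([1.0, 2.0, 3.0, 4.0])
k2, v2 = np.eye(d_k)[1], np.array([5.0, 6.0, 7.0, 8.0])

# Update: M += k^T v (outer product), z += k
for k, v in [(k1, v1), (k2, v2)]:
    M += np.outer(k, v)
    z += k

# Retrieval: A_mem = (q M) / (q z)
q = k1
retrieved = (q @ M) / (q @ z)
print(retrieved)  # [1. 2. 3. 4.] — exactly v1, since the keys don't overlap
```

With real (non-orthogonal) keys the retrieved value is a weighted blend of stored values, which previews the interference limitation discussed below.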

Gated Combination

The final output combines local attention and memory attention:

$$O = \text{sigmoid}(\beta) \odot A_{\text{local}} + (1 - \text{sigmoid}(\beta)) \odot A_{\text{mem}}$$

Where $\beta$ is a learned gating parameter. The model learns when to trust local context vs. long-range memory.
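The gate is a single learned scalar per head. At initialization $\beta = 0$, so $\text{sigmoid}(0) = 0.5$ gives an even blend (the tensors below are toy values for illustration):

```python
import torch

beta = torch.zeros(1)                # learned gate parameter, initialized to 0
A_local = torch.tensor([1.0, 1.0])   # toy local-attention output
A_mem = torch.tensor([3.0, 3.0])     # toy memory-attention output

gate = torch.sigmoid(beta)           # 0.5 at initialization
O = gate * A_local + (1 - gate) * A_mem
print(O)  # tensor([2., 2.]) — an even mix; training moves beta to favor one path
```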

The Key Equations

$$\text{Infini-Attention}(Q, K, V, M_{t-1}, z_{t-1}) = \text{gate} \cdot A_{\text{local}} + (1 - \text{gate}) \cdot A_{\text{mem}}$$

Where:

$$A_{\text{local}} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V \quad \text{(standard attention)}$$

$$A_{\text{mem}} = \frac{\sigma(Q) M_{t-1}}{\sigma(Q) z_{t-1}} \quad \text{(memory retrieval)}$$

$$M_t = M_{t-1} + \sigma(K)^T V \quad \text{(memory update)}$$

$$\text{gate} = \text{sigmoid}(\beta)$$

PyTorch-Style Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class InfiniAttention(nn.Module):
    """
    Infini-Attention: bounded memory for unbounded context.

    Memory size is O(d_k × d_v) — independent of sequence length.
    """
    def __init__(self, d_model, n_heads, segment_size=2048):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.segment_size = segment_size

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

        # Learned gating parameter (one per head)
        self.beta = nn.Parameter(torch.zeros(n_heads))

    def elu_plus_one(self, x):
        """ELU + 1 activation for linear attention kernel."""
        return F.elu(x) + 1

    def forward(self, x):
        B, N, D = x.shape
        H, dk = self.n_heads, self.d_k

        # Project to Q, K, V
        Q = self.W_Q(x).view(B, N, H, dk).transpose(1, 2)  # (B, H, N, dk)
        K = self.W_K(x).view(B, N, H, dk).transpose(1, 2)
        V = self.W_V(x).view(B, N, H, dk).transpose(1, 2)

        # Initialize compressive memory
        M = torch.zeros(B, H, dk, dk, device=x.device)  # O(d_k²) — FIXED SIZE
        z = torch.zeros(B, H, dk, device=x.device)        # Normalization

        outputs = []

        # Process in segments
        for start in range(0, N, self.segment_size):
            end = min(start + self.segment_size, N)
            seg_len = end - start

            Q_seg = Q[:, :, start:end]  # (B, H, seg_len, dk)
            K_seg = K[:, :, start:end]
            V_seg = V[:, :, start:end]

            # 1. Local attention (standard softmax attention within segment)
            A_local = F.scaled_dot_product_attention(
                Q_seg, K_seg, V_seg, is_causal=True
            )

            # 2. Memory-based attention (linear attention retrieval)
            Q_prime = self.elu_plus_one(Q_seg)  # (B, H, seg_len, dk)
            K_prime = self.elu_plus_one(K_seg)

            # Retrieve from memory: (B, H, seg_len, dk) @ (B, H, dk, dk)
            mem_numerator = torch.matmul(Q_prime, M)  # (B, H, seg_len, dk)
            mem_denominator = torch.matmul(
                Q_prime, z.unsqueeze(-1)
            ) + 1e-6  # (B, H, seg_len, 1)
            A_mem = mem_numerator / mem_denominator

            # 3. Gated combination
            gate = torch.sigmoid(self.beta).view(1, H, 1, 1)
            output = gate * A_local + (1 - gate) * A_mem
            outputs.append(output)

            # 4. Update memory with this segment's information
            # M += K'^T @ V  (outer product accumulation)
            M = M + torch.matmul(K_prime.transpose(-2, -1), V_seg)
            z = z + K_prime.sum(dim=-2)

        # Concatenate all segment outputs
        full_output = torch.cat(outputs, dim=2)  # (B, H, N, dk)
        full_output = full_output.transpose(1, 2).reshape(B, N, D)
        return self.W_O(full_output)


# Memory comparison
d_model = 4096
n_heads = 32
dk = d_model // n_heads

for n_tokens in [1_000, 10_000, 100_000, 1_000_000]:
    # Standard attention KV cache: K and V, FP16 (2 bytes per value)
    kv_standard = 2 * n_tokens * d_model * 2  # bytes

    # Infini-Attention memory: M (dk × dk) plus z (dk) per head, FP16
    mem_infini = n_heads * (dk * dk + dk) * 2
    # Plus one segment's KV cache
    segment_kv = 2 * 2048 * d_model * 2

    total_infini = mem_infini + segment_kv

    print(f"n={n_tokens:>10,}  Standard KV: {kv_standard/(1024**2):>9,.1f} MB  "
          f"Infini: {total_infini/(1024**2):>6.1f} MB  "
          f"Savings: {kv_standard/total_infini:>6.0f}×")

Output:

n=     1,000  Standard KV:      15.6 MB  Infini:   33.0 MB  Savings:      0×
n=    10,000  Standard KV:     156.2 MB  Infini:   33.0 MB  Savings:      5×
n=   100,000  Standard KV:   1,562.5 MB  Infini:   33.0 MB  Savings:     47×
n= 1,000,000  Standard KV:  15,625.0 MB  Infini:   33.0 MB  Savings:    473×

At 1M tokens, Infini-Attention uses 473× less memory than standard attention, and the gap grows linearly with sequence length because the compressive memory stays fixed.

| Approach | Memory | Context | Quality |
|---|---|---|---|
| Standard Attention | $O(n \cdot d)$ | Fixed | Best |
| Transformer-XL | $O(n \cdot d)$ | Segment + cache | Good |
| Memorizing Transformer | $O(n \cdot d)$ | Full (kNN lookup) | Good |
| Infini-Attention | $O(d^2)$ | Unbounded | Moderate |
| Linear Attention | $O(d^2)$ | Unbounded | Lower |

Limitations

  1. Compressive memory is lossy. Information from early segments is compressed into a fixed-size matrix. Details are inevitably lost.

  2. Interference. As more key-value pairs are added to MM, older entries get “overwritten” by similar newer ones. This is called catastrophic interference.

  3. Retrieval precision. Linear attention (used for memory retrieval) is less precise than softmax attention. The model may fail to retrieve specific details.

  4. Training complexity. Training with very long sequences (to fill the memory) requires significant compute.
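Limitation 2 is easy to demonstrate with the toy associative memory. Once keys overlap, retrieval returns a blend of values rather than an exact match (again a sketch with $\sigma$ as the identity; the real model uses ELU+1 and learned keys):

```python
import numpy as np

d = 4
M = np.zeros((d, d))   # compressive memory
z = np.zeros(d)        # normalization vector

# Two similar (non-orthogonal) keys with different values
k1 = np.array([1.0, 1.0, 0.0, 0.0])
k2 = np.array([1.0, 0.0, 1.0, 0.0])
v1 = np.array([10.0, 0.0, 0.0, 0.0])
v2 = np.array([0.0, 10.0, 0.0, 0.0])

for k, v in [(k1, v1), (k2, v2)]:
    M += np.outer(k, v)
    z += k

# Query with k1: the overlapping first component drags in v2 as well
out = (k1 @ M) / (k1 @ z)
print(out)  # ≈ [6.67, 3.33, 0, 0], not v1 — the similar key k2 has interfered
```

The more overlapping keys the memory absorbs, the blurrier each retrieval becomes; this is the price of compressing unbounded context into a fixed-size matrix.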

Despite these limitations, Infini-Attention represents an important direction: bounded memory for unbounded context is the only way to truly scale to millions of tokens without proportionally scaling hardware.


ByteBell helps engineering teams solve exactly this problem. Instead of stuffing everything into the context window, ByteBell’s Smart Context Refresh retrieves only what matters — keeping your AI sharp, fast, and accurate. Learn more at bytebell.ai