
Introducing Forgetting Transformer (FoX): Revolutionizing Long-Context Language Modeling with Efficient Memory Control

Researchers at Mila & Université de Montréal introduce FoX, a Transformer variant with learnable forget gates that improves long-context language modeling accuracy while adding only minimal computational overhead.

The Challenge of Managing Memory in Transformers

Transformers have reshaped sequence modeling by capturing long-range dependencies without recurrence. Self-attention lets them process all input tokens in parallel, and they achieve state-of-the-art results on natural language tasks. However, unlike recurrent neural networks (RNNs), standard Transformers lack a native mechanism for forgetting irrelevant past information, which can cause inefficiency and noise accumulation in long sequences.

Limitations of Existing Approaches

Traditional RNNs use forget gates to modulate memory retention, but they struggle with long sequences due to fixed-size hidden states. Some Transformer modifications, like ALiBi, add static positional biases to simulate recency effects, but these are not adaptive to input content. Other models such as Mamba-2 and GLA introduce gating in linear attention but often sacrifice normalization and deviate from the Transformer architecture, limiting compatibility with efficient Transformer optimizations.

The Forgetting Transformer (FoX) Architecture

Researchers from Mila & Université de Montréal and MakerMaker AI developed the Forgetting Transformer (FoX), which introduces a Forgetting Attention mechanism by inserting a scalar forget gate into softmax attention. This gate adjusts attention scores dynamically based on the input data, down-weighting irrelevant past tokens.

FoX preserves compatibility with parallel computation and the efficient FlashAttention algorithm, ensuring minimal overhead. Two variants were introduced:

  • FoX (LLaMA-based): The base version, a LLaMA-style Transformer whose softmax attention is replaced with Forgetting Attention.
  • FoX (Pro): An enhanced version featuring output normalization, output gates, and token-shifting mechanisms inspired by recent recurrent models for improved context sensitivity.

Technical Details of Forgetting Attention

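At each timestep t, the forget gate f_t is computed by applying a sigmoid activation to a learned linear transformation of the input x_t, so f_t lies between 0 and 1. The gates bias the attention logits additively: the logit between query i and key j (j ≤ i) receives the term log f_{j+1} + ... + log f_i, so tokens further in the past are down-weighted more strongly. Because this term is the difference of two entries in the running (cumulative) sum of log f, the modified softmax can be computed efficiently without instantiating a large bias matrix. Each attention head has its own forget gate parameters.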
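A naive NumPy reference for a single head makes the gating concrete. It is a minimal sketch, not the authors' code: the function name, the argument layout, and the usual 1/sqrt(d_head) logit scaling are assumptions. For clarity it materializes the full L × L logit matrix, which an efficient kernel avoids.

```python
import numpy as np


def forgetting_attention(x, Wq, Wk, Wv, w_f, b_f):
    """Single-head causal Forgetting Attention as a naive O(L^2) reference.

    x: (L, d_model) inputs; Wq, Wk, Wv: (d_model, d_head) projections;
    w_f: (d_model,) and b_f (scalar) are this head's forget-gate parameters.
    """
    L, d_head = x.shape[0], Wq.shape[1]
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    # log f_t = log sigmoid(w_f . x_t + b_f), computed stably as -log(1 + exp(-z)).
    log_f = -np.logaddexp(0.0, -(x @ w_f + b_f))
    # c_i = sum_{l <= i} log f_l, so c_i - c_j = sum_{l=j+1}^{i} log f_l for j <= i.
    c = np.cumsum(log_f)
    logits = (q @ k.T) / np.sqrt(d_head) + (c[:, None] - c[None, :])
    # Causal mask: query i may only attend to keys j <= i.
    causal = np.tril(np.ones((L, L), dtype=bool))
    logits = np.where(causal, logits, -np.inf)
    # Row-wise softmax (max-subtracted for numerical stability).
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v
```

Setting b_f to a large positive value drives every f_t toward 1, and the output then approaches ordinary causal softmax attention.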

The Pro variant adds output normalization and a key-value shift mechanism to blend current and prior tokens in a learnable way, enhancing flexibility without significantly increasing parameter count.
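The Pro additions can be sketched as follows. The exact parameterization is an assumption, not taken from the paper: a hypothetical per-timestep coefficient alpha_t = sigmoid(w_alpha . x_t) mixes each key with the previous token's key (values would be treated the same way), and the attention output is RMS-normalized and multiplied by a hypothetical sigmoid output gate. All function names are illustrative.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def kv_shift(x, k, w_alpha):
    """Hypothetical data-dependent token shift for keys (values are analogous).

    x: (L, d_model) inputs; k: (L, d_head) keys; w_alpha: (d_model,) weights
    giving one mixing coefficient per timestep.
    Returns alpha_t * k[t-1] + (1 - alpha_t) * k[t], treating k[-1] as zeros.
    """
    alpha = sigmoid(x @ w_alpha)[:, None]                    # (L, 1), in (0, 1)
    k_prev = np.vstack([np.zeros((1, k.shape[1])), k[:-1]])  # previous token's key
    return alpha * k_prev + (1.0 - alpha) * k


def rms_norm(o, eps=1e-6):
    """Scale each row of o to (approximately) unit root-mean-square."""
    return o / np.sqrt(np.mean(o ** 2, axis=-1, keepdims=True) + eps)


def gated_output(o, x, W_g):
    """Hypothetical output normalization followed by a sigmoid output gate."""
    return rms_norm(o) * sigmoid(x @ W_g)
```

Under this parameterization the shift adds only one d_model-sized weight vector per head, which fits the small parameter overhead mentioned above.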

Experimental Results and Performance

Trained and evaluated on the LongCrawl64 dataset (48 billion training tokens), FoX consistently outperformed the standard Transformer and leading recurrent models on long-context language modeling. It achieved:

  • A steeper decline in per-token loss as token position increases.
  • Significantly lower loss at position 64,000 than the baseline Transformer variants, including the LLaMA-style Transformer.
  • Lower perplexity across validation context lengths, with less degradation beyond the 16,384-token training context.
  • Better extrapolation capabilities than competing models like Mamba-2 and DeltaNet.
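Per-position loss curves and context-length perplexity are standard evaluation conventions; the sketch below shows one common way to compute them from a generic (N, L, V) array of logits. It is an illustration, not the paper's evaluation code, and the function names are hypothetical. Entry t of the curve is the average next-token loss at position t, so a model that keeps exploiting distant context shows a curve that continues to fall at large t.

```python
import numpy as np


def per_position_loss(logits, targets):
    """Mean next-token cross-entropy at each position, averaged over sequences.

    logits: (N, L, V) model outputs; targets: (N, L) integer token ids.
    Returns an (L,) array whose entry t is the average loss at position t.
    """
    # Log-softmax over the vocabulary, max-subtracted for stability.
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    # Negative log-probability of each target token.
    nll = -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return nll.mean(axis=0)


def perplexity(loss_by_pos, context_len):
    """Perplexity over the first `context_len` positions: exp(mean loss)."""
    return float(np.exp(loss_by_pos[:context_len].mean()))
```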

The main experiments used 760M-parameter models and the GPT-2 tokenizer from TikToken. FoX favored higher learning rates and smaller head dimensions than the baseline Transformer, suggesting architectural robustness.

Advantages and Implications

FoX introduces a data-driven recency bias that enhances selective memory in Transformers with negligible added computational or memory overhead. It generalizes static biases like ALiBi by learning dynamic forget gates, improving adaptability and accuracy. The Pro variant excels on downstream tasks that require context sensitivity.
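The relationship to ALiBi can be checked directly. ALiBi adds a fixed bias of -m(i - j) to the logit between query i and key j, where m is a per-head slope. If every forget gate is held at the constant value f_t = exp(-m), the Forgetting Attention bias (the sum of log f_l for l = j+1..i) also becomes -m(i - j), so ALiBi is the special case of a constant, input-independent gate. The NumPy sketch below (helper names are hypothetical) verifies this on the causal part of the bias matrix.

```python
import numpy as np


def forget_bias(log_f):
    """D[i, j] = sum_{l=j+1}^{i} log f_l via cumulative sums (used for j <= i)."""
    c = np.cumsum(log_f)
    return c[:, None] - c[None, :]


def alibi_bias(L, m):
    """ALiBi's static linear bias -m * (i - j) for a head with slope m."""
    idx = np.arange(L)
    return -m * (idx[:, None] - idx[None, :])


L, m = 5, 0.25
D_fox = forget_bias(np.full(L, -m))  # constant gate: log f_t = -m, i.e. f_t = exp(-m)
lower = np.tril(np.ones((L, L), dtype=bool))
print(np.allclose(D_fox[lower], alibi_bias(L, m)[lower]))  # True
```

A learned, input-dependent gate lets each head decide per token how fast to decay, where ALiBi fixes the decay rate in advance.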

This work demonstrates that incorporating dynamic, learnable memory forgetting mechanisms into Transformer architectures is both feasible and beneficial. The approach maintains efficiency via FlashAttention compatibility, allowing practical large-scale deployment.
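The forget gates fit into FlashAttention-style kernels because the bias for any tile of queries i and keys j can be rebuilt from a length-L vector c (the cumulative sum of log f) as c_i - c_j. The NumPy sketch below illustrates this with a tiled pass and an online softmax; it is a conceptual illustration under these assumptions, not the authors' GPU kernel, and the function name and block size are hypothetical.

```python
import numpy as np


def forgetting_attention_tiled(q, k, v, log_f, block=4):
    """Tiled causal Forgetting Attention with an online softmax (NumPy sketch).

    Only one block x block tile of logits exists at a time; the forget-gate
    bias for each tile is rebuilt from c = cumsum(log_f).
    """
    L, d = q.shape
    c = np.cumsum(log_f)
    out = np.zeros_like(v)
    for qs in range(0, L, block):
        qe = min(qs + block, L)
        row_max = np.full(qe - qs, -np.inf)    # running max of logits per query
        denom = np.zeros(qe - qs)              # running softmax denominator
        acc = np.zeros((qe - qs, v.shape[1]))  # running weighted sum of values
        # Causal: key tiles starting after the last query in this block are skipped.
        for ks in range(0, qe, block):
            ke = min(ks + block, L)
            logits = q[qs:qe] @ k[ks:ke].T / np.sqrt(d)
            logits += c[qs:qe, None] - c[None, ks:ke]  # forget-gate bias for this tile
            i = np.arange(qs, qe)[:, None]
            j = np.arange(ks, ke)[None, :]
            logits = np.where(j <= i, logits, -np.inf)
            new_max = np.maximum(row_max, logits.max(axis=1))
            rescale = np.exp(row_max - new_max)  # shrink earlier partial sums
            p = np.exp(logits - new_max[:, None])
            denom = denom * rescale + p.sum(axis=1)
            acc = acc * rescale[:, None] + p @ v[ks:ke]
            row_max = new_max
        out[qs:qe] = acc / denom[:, None]
    return out
```

Its output matches the naive full-matrix computation up to floating-point error, while holding only one block × block tile of logits at a time.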

Key Takeaways

  • Forgetting Attention augments softmax attention with learnable forget gates.
  • Two architectural variants: FoX (LLaMA) and FoX (Pro) with extra normalization and gating.
  • FoX outperforms standard Transformers in long-context modeling on large datasets.
  • Maintains low loss and robust perplexity at context lengths up to 64k tokens, well beyond the 16k training context.
  • Generalizes static biases like ALiBi via dynamic gating.
  • Hardware-efficient and compatible with FlashAttention for scalable use.

For further details, check the original paper and code repositories.
