Just read twice: closing the recall gap for recurrent language models

(2407.05483)
Published Jul 7, 2024 in cs.CL and cs.LG

Abstract

Recurrent LLMs that compete with Transformers in language modeling perplexity are emerging at a rapid rate (e.g., Mamba, RWKV). Excitingly, these architectures use a constant amount of memory during inference. However, due to the limited memory, recurrent LMs cannot recall and use all the information in long contexts, leading to brittle in-context learning (ICL) quality. A key challenge for efficient LMs is selecting what information to store versus discard. In this work, we observe the order in which information is shown to the LM impacts the selection difficulty. To formalize this, we show that the hardness of information recall reduces to the hardness of a problem called set disjointness (SD), a quintessential problem in communication complexity that requires a streaming algorithm (e.g., recurrent model) to decide whether inputted sets are disjoint. We empirically and theoretically show that the recurrent memory required to solve SD changes with set order, i.e., whether the smaller set appears first in-context. Our analysis suggests that, to mitigate the reliance on data order, we can put information in the right order in-context or process prompts non-causally. Towards that end, we propose: (1) JRT-Prompt, where context gets repeated multiple times in the prompt, effectively showing the model all data orders. This gives $11.0 \pm 1.3$ points of improvement, averaged across $16$ recurrent LMs and the $6$ ICL tasks, with $11.9\times$ higher throughput than FlashAttention-2 for generation prefill (length $32$k, batch size $16$, NVIDIA H100). We then propose (2) JRT-RNN, which uses non-causal prefix-linear-attention to process prompts and provides $99\%$ of Transformer quality at $360$M params., $30$B tokens and $96\%$ at $1.3$B params., $50$B tokens on average across the tasks, with $19.2\times$ higher throughput for prefill than FA2.

Recurrent models struggle with long-context memory; ordering data impacts difficulty; showing input twice helps.

Overview

  • The paper introduces JRT-Prompt and JRT-RNN, two methods that improve in-context learning (ICL) quality in memory-efficient recurrent language models (RLMs).

  • Theoretical and empirical analyses show that both methods reduce the models' sensitivity to the order in which information appears in context, improving recall and reaching near-Transformer quality at a fraction of the inference memory and compute.

  • JRT-Prompt repeats the input context multiple times in the prompt, while JRT-RNN processes the prompt non-causally with prefix linear attention; both offer practical gains for real-world applications with memory constraints.

Analyzing Efficiency in Recurrent Language Models with JRT-Prompt and JRT-RNN

This paper presents a critical examination of memory-efficient recurrent language models (RLMs) relative to Transformer architectures, focusing on the challenges they face in in-context learning (ICL) quality. The authors introduce two methods, JRT-Prompt and JRT-RNN, designed to make better use of the models' fixed-size memory during inference.

Key Insights and Contributions

  1. Order Sensitivity in Data Processing: The paper identifies that the order in which data is presented significantly impacts the performance of RLMs. This order sensitivity is formalized through the set disjointness (SD) problem, a canonical problem in communication complexity theory.

  2. Theoretical and Empirical Analysis: The authors provide a rigorous theoretical framework alongside empirical evidence to demonstrate that the memory required to solve SD varies based on the order of data presentation. They derive that models using JRT (Just-Read-Twice) prompting can reduce this sensitivity, thereby improving recall efficiency.

  3. JRT-Prompt: JRT-Prompt is a simple yet effective method in which the input context is repeated multiple times in the prompt before the model produces its completion. Reading the context more than once lets the model condition what it stores on the full input, mitigating data-order sensitivity. The method yields an improvement of 11.0 ± 1.3 points averaged across 16 recurrent LMs and 6 ICL tasks.

  4. JRT-RNN Architecture: JRT-RNN builds the benefits of JRT-Prompt into the architecture by using prefix linear attention to process the prompt: the input prefix is handled non-causally while generation remains causal. It reaches 99% of Transformer quality at 360M parameters and 96% at 1.3B parameters, with up to 19.2x higher prefill throughput than FlashAttention-2 (a simplified sketch of the prefix-linear-attention pattern follows this list).
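
To make the prefix-linear-attention idea concrete, the NumPy sketch below mixes the prompt tokens non-causally and then continues causally from a recurrent state that already summarizes the full prefix. The feature map, single head, and normalization here are illustrative assumptions for a generic linear-attention layer, not JRT-RNN's exact parameterization.

```python
import numpy as np

def feature_map(x):
    # Positive feature map standing in for the learned map used in
    # linear attention (an illustrative choice, not JRT-RNN's).
    return np.maximum(x, 0.0) + 1e-6

def prefix_linear_attention(q, k, v, prefix_len):
    """Sketch: non-causal linear attention over the first `prefix_len`
    tokens (the prompt), causal linear attention afterwards.
    q, k, v: arrays of shape (seq_len, d)."""
    phi_q, phi_k = feature_map(q), feature_map(k)
    seq_len, d = q.shape
    out = np.zeros_like(v)

    # Non-causal prefix: every prompt position reads the whole prompt.
    kv = phi_k[:prefix_len].T @ v[:prefix_len]      # (d, d) summary
    z = phi_k[:prefix_len].sum(axis=0)              # (d,) normalizer
    for t in range(prefix_len):
        out[t] = (phi_q[t] @ kv) / (phi_q[t] @ z)

    # Causal continuation: the recurrent state starts from the
    # full-prefix summary and is updated token by token.
    for t in range(prefix_len, seq_len):
        kv = kv + np.outer(phi_k[t], v[t])
        z = z + phi_k[t]
        out[t] = (phi_q[t] @ kv) / (phi_q[t] @ z)
    return out

# Example: a 6-token prompt followed by 4 generated positions.
rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(10, 8)) for _ in range(3))
y = prefix_linear_attention(q, k, v, prefix_len=6)
```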

Practical and Theoretical Implications

Practical Implications:

  • Efficiency: With the demonstrated gains in throughput and recall efficiency, JRT-Prompt and JRT-RNN provide substantial practical benefits in deploying RLMs for real-world applications with memory constraints.
  • Scaling: The methods hold up across model sizes and training-token budgets, delivering consistent gains as RLMs scale.
  • Broader Applicability: Because JRT-Prompt works with off-the-shelf models, it is broadly applicable without modifying or retraining existing RLMs; a minimal prompt-construction sketch follows this list.
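
As a rough illustration of how little is needed to apply JRT-Prompt to an existing model, the helper below repeats the context before the question; the separator, repeat count, and the `model.generate` call are illustrative assumptions rather than the paper's exact prompt format.

```python
def jrt_prompt(context: str, question: str, repeats: int = 2) -> str:
    """Build a Just-Read-Twice style prompt: repeat the context before
    the question so the recurrent state can decide what to keep having
    already seen the full input once."""
    return "\n\n".join([context] * repeats + [question])

# Hypothetical usage with any off-the-shelf recurrent LM:
#   prompt = jrt_prompt(document_text, "Who signed the agreement?")
#   answer = model.generate(prompt)
```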

Theoretical Implications:

  • Memory-Recall Tradeoff: The paper deepens understanding of the memory-recall tradeoff in RLMs, providing a foundation for further architecture improvements.
  • Order Sensitivity: By invoking set disjointness, the paper makes a strong case for treating data order as a first-class design consideration, in contrast to analyses that treat inputs uniformly regardless of order; a toy streaming example follows this list.
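
The toy one-pass check below is a simplification of the communication-complexity argument, not the paper's formal construction: the disjointness answer is the same whichever set is streamed first, but the memory a single-pass algorithm must hold scales with whichever set it has to store.

```python
def disjoint_one_pass(stored_set, streamed_items):
    """Decide set disjointness in one pass over `streamed_items`,
    keeping only `stored_set` in memory (~|stored_set| space)."""
    stored = set(stored_set)
    return all(x not in stored for x in streamed_items)

small, large = {3, 7}, set(range(10_000))
# Same answer in both orders, but the memory held while streaming
# is ~|small| in the first call and ~|large| in the second.
assert disjoint_one_pass(small, large) == disjoint_one_pass(large, small)
```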

Future Developments

The promising results from JRT-Prompt and JRT-RNN open avenues for further explorations:

  • Selective Repetition Strategies: Optimizing the repetition strategy for specific types of tasks could refine the balance between memory use and inference time, providing a more tailored approach.
  • Hybrid Architectures: Combining JRT-RNN with other advanced techniques, such as mixture of experts or sparsity methods, may extend the Pareto frontier of quality and efficiency.
  • Extended Context Lengths: Exploring the application of these methods to significantly longer contexts, particularly in domains requiring comprehensive document understanding, could yield further insights into maintaining efficiency at scale.

In conclusion, the paper makes substantial advances in memory-efficient in-context learning for recurrent language models, emphasizing the pivotal role of data order and introducing effective solutions in JRT-Prompt and JRT-RNN. These contributions offer a sound framework for improving RLMs both in theory and in practical deployments. Further research will likely build on these findings to continue refining the efficiency and capability of language model architectures.
