Just read twice: closing the recall gap for recurrent language models (2407.05483v1)

Published 7 Jul 2024 in cs.CL and cs.LG

Abstract: Recurrent large language models that compete with Transformers in language modeling perplexity are emerging at a rapid rate (e.g., Mamba, RWKV). Excitingly, these architectures use a constant amount of memory during inference. However, due to the limited memory, recurrent LMs cannot recall and use all the information in long contexts leading to brittle in-context learning (ICL) quality. A key challenge for efficient LMs is selecting what information to store versus discard. In this work, we observe the order in which information is shown to the LM impacts the selection difficulty. To formalize this, we show that the hardness of information recall reduces to the hardness of a problem called set disjointness (SD), a quintessential problem in communication complexity that requires a streaming algorithm (e.g., recurrent model) to decide whether inputted sets are disjoint. We empirically and theoretically show that the recurrent memory required to solve SD changes with set order, i.e., whether the smaller set appears first in-context. Our analysis suggests, to mitigate the reliance on data order, we can put information in the right order in-context or process prompts non-causally. Towards that end, we propose: (1) JRT-Prompt, where context gets repeated multiple times in the prompt, effectively showing the model all data orders. This gives $11.0 \pm 1.3$ points of improvement, averaged across $16$ recurrent LMs and the $6$ ICL tasks, with $11.9\times$ higher throughput than FlashAttention-2 for generation prefill (length $32$k, batch size $16$, NVidia H100). We then propose (2) JRT-RNN, which uses non-causal prefix-linear-attention to process prompts and provides $99\%$ of Transformer quality at $360$M params., $30$B tokens and $96\%$ at $1.3$B params., $50$B tokens on average across the tasks, with $19.2\times$ higher throughput for prefill than FA2.

Summary

  • The paper introduces JRT-Prompt and JRT-RNN, two methods that narrow the recall gap between recurrent language models and Transformers by reducing sensitivity to the order in which in-context data appears.
  • The paper pairs a theoretical reduction to set disjointness with empirical evidence: repeating the context yields an average gain of 11.0 ± 1.3 points across 16 recurrent LMs and 6 ICL tasks, and JRT-RNN reaches near-Transformer quality.
  • The paper demonstrates practical scalability and efficiency, with up to 19.2x higher prefill throughput than FlashAttention-2.

Analyzing Efficiency in Recurrent LLMs with JRT-Prompt and JRT-RNN

This paper examines memory-efficient recurrent language models (RLMs) relative to Transformer architectures, focusing on the in-context learning (ICL) quality gap that stems from their fixed-size recurrent state. The authors introduce two methods, JRT-Prompt and JRT-RNN, designed to improve recall from long contexts while preserving the constant-memory inference that makes recurrent models attractive.

Key Insights and Contributions

  1. Order Sensitivity in Data Processing: The paper identifies that the order in which data appears in-context significantly affects the recall performance of RLMs. It formalizes this sensitivity through a reduction to the set disjointness (SD) problem, a canonical problem in communication complexity.
  2. Theoretical and Empirical Analysis: The authors show, both theoretically and empirically, that the recurrent memory required to solve SD depends on data order, specifically on whether the smaller set appears first, and that Just-Read-Twice (JRT) prompting reduces this sensitivity and improves recall.
  3. JRT-Prompt: JRT-Prompt is a simple but effective method in which the input context is repeated before the model generates its completion. Repetition lets the model condition on the full context before deciding what to store, mitigating data-order sensitivity, and yields an improvement of 11.0 ± 1.3 points averaged across 16 recurrent LMs and 6 ICL tasks (a prompt-construction sketch appears after this list).
  4. JRT-RNN Architecture: JRT-RNN builds these benefits into the architecture via prefix linear attention: the prompt (prefix) is processed non-causally while generation remains causal (a masking sketch appears after this list). It reaches 99% of Transformer quality at 360M parameters (30B tokens) and 96% at 1.3B parameters (50B tokens), with up to 19.2x higher prefill throughput than FlashAttention-2.
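
A minimal sketch of the JRT-Prompt recipe is shown below in Python. The helper name and prompt template are illustrative assumptions; the paper's core idea is simply to repeat the in-context documents or examples before the query so the model reads the data twice.

```python
def jrt_prompt(context: str, question: str, n_repeats: int = 2) -> str:
    """Sketch of JRT-Prompt: repeat the context before the query so a
    recurrent model gets a second pass over the information it must recall.
    The template below is an assumption, not the paper's exact format."""
    repeated = "\n\n".join([context] * n_repeats)
    return f"{repeated}\n\nQuestion: {question}\nAnswer:"


# Toy usage: the model sees both documents twice before answering.
prompt = jrt_prompt(
    context="Doc 1: The warranty lasts 24 months. Doc 2: Returns close after 30 days.",
    question="How long does the warranty last?",
)
print(prompt)
```

Because the repetition happens purely at the prompt level, it can be applied to off-the-shelf recurrent models without retraining, at the cost of a longer prefill.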
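
The masking pattern behind prefix linear attention can be illustrated with a small NumPy sketch, written in the quadratic form for clarity (efficient implementations use the recurrent or chunked form). The feature map and function name are assumptions, and JRT-RNN's full architecture, including its prefix-side training objective, is not reproduced here.

```python
import numpy as np

def prefix_linear_attention(q, k, v, prefix_len, eps=1e-6):
    """Prefix-LM mask with linear-attention-style weights: prefix tokens
    attend to the whole prefix (non-causal); later tokens attend causally."""
    seq_len = q.shape[0]
    phi = lambda x: np.maximum(x, 0.0) + 1e-2      # simple positive feature map (assumed)
    scores = phi(q) @ phi(k).T                     # (seq_len, seq_len) unnormalized weights

    mask = np.tril(np.ones((seq_len, seq_len)))    # causal baseline
    mask[:prefix_len, :prefix_len] = 1.0           # the prefix block is fully visible

    scores = scores * mask
    return (scores @ v) / (scores.sum(axis=-1, keepdims=True) + eps)

# Toy usage: an 8-token prompt (prefix) followed by 4 decoded tokens, head dim 16.
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((12, 16)) for _ in range(3))
print(prefix_linear_attention(q, k, v, prefix_len=8).shape)  # (12, 16)
```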

Practical and Theoretical Implications

  1. Practical Implications:
    • Efficiency: The demonstrated gains in throughput and recall quality make JRT-Prompt and JRT-RNN practical options for deploying RLMs in memory-constrained, real-world settings.
    • Scaling: The improvements hold across the model sizes and training budgets studied (360M parameters / 30B tokens and 1.3B parameters / 50B tokens), suggesting the approach scales consistently.
    • Broader Applicability: Because JRT-Prompt works with off-the-shelf models, it applies broadly without retraining or modifying existing RLMs.
  2. Theoretical Implications:
    • Memory-Recall Tradeoff: The paper deepens the understanding of the memory-recall tradeoff in RLMs, providing a foundation for further improvements in architecture design.
    • Order Sensitivity: The reduction to set disjointness makes a strong case for treating data order as a first-class concern in architecture and prompt design, in contrast to analyses that treat inputs uniformly regardless of order (a toy streaming example follows this list).
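
As a toy illustration of the order effect, the one-pass (streaming) view of set disjointness sketched below must buffer whichever set arrives first, so memory usage depends entirely on presentation order; the function is illustrative, not taken from the paper.

```python
def streaming_disjointness(first, second):
    """One-pass set disjointness: buffer the first set (the analogue of the
    recurrent state), then stream the second set and test membership."""
    stored = set(first)                              # memory grows with the first set
    disjoint = not any(x in stored for x in second)
    return disjoint, len(stored)                     # (answer, elements remembered)

small, large = {3, 7}, set(range(1000))
print(streaming_disjointness(small, large))  # smaller set first: stores 2 elements
print(streaming_disjointness(large, small))  # larger set first: stores 1000 elements
```

Reading the data twice removes this asymmetry: after one full pass, the model already knows which items matter before it must decide what to keep on the second pass.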

Future Developments

The promising results from JRT-Prompt and JRT-RNN open avenues for further explorations:

  • Selective Repetition Strategies: Optimizing the repetition strategy for specific types of tasks could refine the balance between memory use and inference time, providing a more tailored approach.
  • Hybrid Architectures: Combining JRT-RNN with other advanced techniques, such as mixture of experts or sparsity methods, may extend the Pareto frontier of quality and efficiency.
  • Extended Context Lengths: Exploring the application of these methods to significantly longer contexts, particularly in domains requiring comprehensive document understanding, could yield further insights into maintaining efficiency at scale.

In conclusion, the paper makes substantial advances in memory-efficient in-context learning for recurrent language models, emphasizing the pivotal role of data order and introducing effective remedies in JRT-Prompt and JRT-RNN. These contributions offer a sound framework for improving RLMs in both theory and practice, and future research will likely build on these findings to continue refining the efficiency and capability of language model architectures.