
Learning to (Learn at Test Time): RNNs with Expressive Hidden States

(2407.04620)
Published Jul 5, 2024 in cs.LG, cs.AI, and cs.CL

Abstract

Self-attention performs well in long context but has quadratic complexity. Existing RNN layers have linear complexity, but their performance in long context is limited by the expressive power of their hidden state. We propose a new class of sequence modeling layers with linear complexity and an expressive hidden state. The key idea is to make the hidden state a machine learning model itself, and the update rule a step of self-supervised learning. Since the hidden state is updated by training even on test sequences, our layers are called Test-Time Training (TTT) layers. We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden state is a linear model and a two-layer MLP respectively. We evaluate our instantiations at the scale of 125M to 1.3B parameters, comparing with a strong Transformer and Mamba, a modern RNN. Both TTT-Linear and TTT-MLP match or exceed the baselines. Similar to Transformer, they can keep reducing perplexity by conditioning on more tokens, while Mamba cannot after 16k context. With preliminary systems optimization, TTT-Linear is already faster than Transformer at 8k context and matches Mamba in wall-clock time. TTT-MLP still faces challenges in memory I/O, but shows larger potential in long context, pointing to a promising direction for future research.

Figure: Sequence modeling layers viewed as hidden states that transition via test-time training (TTT) on a self-supervised loss.

Overview

  • The paper introduces a novel architecture called Test-Time Training (TTT) layers, which improve the expressiveness of Recurrent Neural Networks (RNNs) for sequence modeling by using a machine learning model as the hidden state.

  • The TTT layers update their hidden state with a step of self-supervised learning, enabling the model to keep adapting during test time; this is demonstrated through two key instantiations, TTT-Linear and TTT-MLP.

  • Empirical evaluations show that TTT layers outperform existing RNNs and match or exceed the performance of Transformer self-attention, particularly on long-context sequences, while retaining linear computational complexity.

Learning to (Learn at Test Time): RNNs with Expressive Hidden States

The paper presents a novel architecture for sequence modeling called Test-Time Training (TTT) layers, addressing the challenge of modeling long-context sequences efficiently. The primary innovation is to make the hidden state a machine learning model itself, so that the update rule becomes a step of self-supervised learning, while keeping complexity linear in sequence length.

Background and Motivation

Recurrent Neural Networks (RNNs) traditionally suffer from limitations in expressive power when dealing with long sequences, although they offer linear complexity. In contrast, Transformer's self-attention mechanism, despite its superior performance in long contexts, incurs quadratic computational complexity. The authors aim to strike a balance between these paradigms by enhancing the expressiveness of RNNs while maintaining their linear computational advantages.

TTT Layer Framework

The crux of the proposed method is to use a machine learning model itself as the hidden state of a sequence modeling layer. The update rule is a gradient step on a self-supervised loss, so the hidden state keeps learning even during test time. This design leads to two key instantiations, TTT-Linear and TTT-MLP, whose hidden states are a linear model and a two-layer multi-layer perceptron (MLP), respectively; the surrounding (outer-loop) parameters are trained end-to-end with standard neural network training.
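To make the mechanism concrete, here is a minimal NumPy sketch of the sequential (online) form of a TTT-Linear-style layer. The projection matrices `theta_K`, `theta_V`, `theta_Q`, the learning rate `eta`, and the squared-error reconstruction loss are illustrative stand-ins for the learned outer-loop components described in the paper, not the authors' exact implementation.

```python
import numpy as np

def ttt_linear_forward(xs, theta_K, theta_V, theta_Q, eta=0.1):
    """Sequential (online) form of a TTT-Linear-style layer.

    The hidden state is a linear model W (d x d). Each token triggers one
    gradient step on a self-supervised reconstruction loss, and the output
    is the updated model applied to a "query" view of the token.
    xs: (T, d) token embeddings; theta_*: (d, d) illustrative projections.
    """
    T, d = xs.shape
    W = np.zeros((d, d))                      # initial hidden state W_0
    outputs = []
    for x in xs:                              # process one token at a time
        k, v, q = theta_K @ x, theta_V @ x, theta_Q @ x
        # Self-supervised loss: l(W; x) = ||W k - v||^2
        grad = 2.0 * np.outer(W @ k - v, k)   # gradient of l w.r.t. W
        W = W - eta * grad                    # update rule: one gradient step
        outputs.append(W @ q)                 # output rule: z_t = W_t q_t
    return np.stack(outputs)                  # (T, d)
```

Because the state is a fixed-size model rather than a growing cache of past tokens, memory and per-token compute stay constant in sequence length.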

Theoretical Insights and Implementation

The paper explains that all sequence modeling layers, including self-attention and existing RNNs, can be expressed in a common framework: a hidden state, an update rule that folds each new token into that state, and an output rule that reads it out. The differences lie in how expressive the hidden state is and how it is updated. By using self-supervised learning to compress the historical context into the hidden state, TTT layers inherently train on the fly during test time. This reformulation exploits the expressive power of machine learning models to capture more sophisticated dependencies within long sequences.
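Written out, the shared framework and its TTT instantiation look roughly as follows (notation lightly simplified; the corrupted-input reconstruction loss shown is one of the self-supervised tasks the paper discusses):

```latex
% Generic sequence-modeling layer: a hidden state s_t, an update rule, and an output rule
s_t = \mathrm{update}(s_{t-1}, x_t), \qquad z_t = \mathrm{output}(s_t, x_t)

% TTT instantiation: the hidden state is the weights W_t of a model f,
% updated by one gradient step on a self-supervised loss \ell
W_t = W_{t-1} - \eta\,\nabla\ell(W_{t-1}; x_t), \qquad z_t = f(x_t; W_t),
\qquad \text{e.g. } \ell(W; x_t) = \lVert f(\tilde{x}_t; W) - x_t \rVert^2
```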

For hardware efficiency, the paper introduces mini-batch TTT and a dual form. Mini-batch TTT computes gradients for several tokens simultaneously, balancing parallelism against the quality of each update. The dual form improves efficiency further by avoiding the materialization of intermediate hidden states and recasting the computation as matrix-matrix operations that map well onto modern GPU architectures.
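As a rough illustration of the mini-batch idea (again a NumPy sketch with illustrative names, continuing the setup above, not the paper's GPU kernel): within each chunk of b tokens, every token's gradient is evaluated at the chunk's starting state, so the gradients can be computed in parallel and the per-token states recovered as cumulative sums. The dual form goes further and avoids materializing the per-token W matrices entirely.

```python
import numpy as np

def ttt_linear_minibatch(xs, theta_K, theta_V, theta_Q, eta=0.1, b=16):
    """Mini-batch TTT sketch: within each chunk of b tokens, all gradients are
    evaluated at the chunk's starting state W_base, so they can be computed in
    parallel; per-token states are then cumulative sums of those gradients.
    """
    T, d = xs.shape
    W_base = np.zeros((d, d))
    outputs = []
    for start in range(0, T, b):
        chunk = xs[start:start + b]                      # (b', d) tokens in this mini-batch
        K = chunk @ theta_K.T                            # (b', d) "key" views
        V = chunk @ theta_V.T                            # (b', d) reconstruction targets
        Q = chunk @ theta_Q.T                            # (b', d) "query" views
        errs = K @ W_base.T - V                          # residuals W_base k_t - v_t, in parallel
        grads = 2.0 * errs[:, :, None] * K[:, None, :]   # (b', d, d) per-token gradients
        W_t = W_base - eta * np.cumsum(grads, axis=0)    # W_t = W_base - eta * sum_{s<=t} grad_s
        outputs.append(np.einsum('tij,tj->ti', W_t, Q))  # z_t = W_t q_t
        W_base = W_t[-1]                                 # carry the last state into the next chunk
    return np.concatenate(outputs)                       # (T, d)
```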

Empirical Evaluations

Evaluations across context lengths from 2k to 32k tokens and model scales from 125M to 1.3B parameters show that TTT-Linear and TTT-MLP match or exceed a modern RNN baseline, Mamba, as well as a strong Transformer, in perplexity while remaining computationally efficient. For example, TTT-Linear reaches lower perplexity than Mamba while using fewer FLOPs, and, like the Transformer, it keeps improving as the context grows beyond 16k tokens, where Mamba stops improving.

Implications and Future Directions

The results imply that TTT layers can serve as practical replacements for traditional sequence modeling components in long-context applications, potentially within LLMs. By blurring the line between training and inference, the method suggests a new pathway for designing efficient sequence models.

Future research can delve into optimizing the outer-loop parameterization, extending systems optimization, and exploring more expressive forms of the inner model $f$. Given the promising results in long context handling, further scaling up model size and context length seems a viable direction. Additionally, the concept of multi-level learning to learn, wherein a self-attention layer itself becomes subject to TTT, represents an ambitious and intriguing frontier.

Conclusion

This paper presents a thoughtful and rigorous analysis of sequence modeling through the lens of test-time training, offering a conceptual and practical framework that blends the strengths of RNNs and Transformers. By iteratively developing and empirically testing TTT layers, the researchers make a strong case for their adoption in scenarios requiring efficient long-context understanding. The methodology, characterized by innovative use of self-supervised learning and hardware-conscious optimization, paves the way for future exploration in AI's capability to learn continuously and contextually.
