
HMT: Hierarchical Memory Transformer for Long Context Language Processing

(2405.06067)
Published May 9, 2024 in cs.CL and cs.LG

Abstract

Transformer-based large language models (LLMs) have been widely used in language processing applications. However, most of them restrict the context window, limiting the number of input tokens the model can attend to. Previous work on recurrent models can memorize past tokens to enable unlimited context while maintaining effectiveness. However, these models have "flat" memory architectures, which are limited in how they select and filter information. Since humans are good at learning and self-adjustment, we speculate that imitating the brain's memory hierarchy is beneficial for model memorization. We propose the Hierarchical Memory Transformer (HMT), a novel framework that enables and improves models' long-context processing ability by imitating human memorization behavior. Leveraging memory-augmented segment-level recurrence, we organize the memory hierarchy by preserving tokens from early input segments, passing memory embeddings along the sequence, and recalling relevant information from history. Evaluating general language modeling (Wikitext-103, PG-19) and question-answering (PubMedQA) tasks, we show that HMT steadily improves the long-context processing ability of both context-constrained and long-context models. With an additional 0.5%-2% of parameters, HMT can easily plug into and augment future LLMs to handle long context effectively. Our code is open-sourced on GitHub: https://github.com/OswaldHe/HMT-pytorch.

Overview

  • The Hierarchical Memory Transformer (HMT) is designed to extend the capabilities of conventional transformers for processing long contexts by mimicking the hierarchical structure of human memory.

  • HMT employs a memory reinforcement mechanism that includes sensory, short-term, and long-term memory banks to effectively recall and integrate information from previous segments.

  • Experimental results demonstrate significant improvements on tasks such as language modeling and question answering, and show that HMT can be integrated into pre-existing models with minimal modifications.

Hierarchical Memory Transformer (HMT) for Long Context Processing

Introduction

Transformers have revolutionized NLP, but they have a key limitation: the maximum length of context they can handle. Typical transformer models, including popular LLMs like Llama 2, process a fixed number of tokens at a time and are not well suited to tasks requiring very long contexts, such as book summarization or document-based question answering.

The Hierarchical Memory Transformer (HMT) proposes a novel approach to extend the capabilities of transformers for long-context scenarios. It does so by mimicking how human memory works, utilizing a memory-augmented segment-level recurrence to handle longer contexts more effectively.
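The core loop can be sketched as follows: a long input is split into fixed-size segments that are processed in order, with a memory embedding carried from one segment to the next. The segment length, the `memory=` argument, and the model's return signature below are illustrative assumptions, not HMT's exact interface.

```python
import torch

def segment_recurrent_forward(model, input_ids, segment_len=512):
    """Sketch of memory-augmented segment-level recurrence (assumed interface)."""
    memory = None                      # memory embedding passed along the sequence
    all_logits = []
    for start in range(0, input_ids.size(1), segment_len):
        segment = input_ids[:, start:start + segment_len]
        # The backbone is assumed to accept an optional memory embedding and to
        # return logits plus an updated memory for the next segment.
        logits, memory = model(segment, memory=memory)
        all_logits.append(logits)
    return torch.cat(all_logits, dim=1)
```

Because only one segment is materialized at a time, the attention cost stays bounded by the segment length rather than by the full sequence length.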

Hierarchical Memorization in HMT

HMT is designed to imitate the hierarchical structure of human memory, which consists of sensory, short-term, and long-term memory (a code sketch of these banks follows the list):

  • Sensory Memory: HMT uses the last few token embeddings from the previous segment, allowing it to process information that is immediately relevant.
  • Short-term Memory: Each segment is summarized into a single embedding. This summarized embedding is then used to recall relevant information from previously processed segments.
  • Long-term Memory: HMT maintains a cache of the most recent memory embeddings, effectively transforming it into a long-term memory bank. This cached memory is utilized to recall and integrate information from distant past segments.
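A minimal sketch of how these three banks could be organized is given below; the bank sizes, tensor shapes, and update rule are assumptions made for illustration rather than the paper's exact design.

```python
from collections import deque
import torch

class HierarchicalMemory:
    """Toy container for the sensory, short-term, and long-term banks above."""

    def __init__(self, num_sensory_tokens=32, cache_size=300):
        self.num_sensory_tokens = num_sensory_tokens
        self.sensory = None                         # tail token embeddings of the last segment
        self.short_term = None                      # summary embedding of the current segment
        self.long_term = deque(maxlen=cache_size)   # bounded cache of past summaries

    def update(self, segment_embeddings, summary_embedding):
        # Sensory memory: keep the last few token embeddings of this segment.
        self.sensory = segment_embeddings[:, -self.num_sensory_tokens:, :]
        # Short-term memory: the single embedding that summarizes this segment.
        self.short_term = summary_embedding
        # Long-term memory: append the summary to a cache of recent segments.
        self.long_term.append(summary_embedding)

    def long_term_bank(self):
        # Stack cached summaries into a (batch, cache_len, hidden) tensor for recall.
        return torch.stack(list(self.long_term), dim=1)
```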

Memory Recall Mechanism

The memory recall mechanism is one of the key innovations in HMT. It involves three main steps, sketched in code after the list:

  1. Representation Extraction: The initial part of a segment is used to generate an embedding that summarizes the segment.
  2. Memory Search: This summary embedding is then used as a query to find the most relevant information from the cache of previous memory embeddings using a cross-attention mechanism.
  3. Augmenting Current Segment: The current segment is augmented with the recalled memory before being processed by the transformer model.
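The search step can be sketched as single-head cross-attention over the cached embeddings; the head count, scaling, and the exact way the recalled memory is injected back into the segment are simplifying assumptions here.

```python
import torch
import torch.nn.functional as F

def recall_memory(summary_query, memory_bank):
    """Cross-attention sketch of the memory search step.

    summary_query: (batch, 1, hidden) embedding summarizing the segment's prefix.
    memory_bank:   (batch, cache_len, hidden) cached memory embeddings.
    """
    scale = summary_query.size(-1) ** 0.5
    scores = torch.matmul(summary_query, memory_bank.transpose(1, 2)) / scale  # (batch, 1, cache_len)
    weights = F.softmax(scores, dim=-1)
    recalled = torch.matmul(weights, memory_bank)                              # (batch, 1, hidden)
    return recalled
```

The recalled embedding is then combined with the current segment (step 3) before the backbone transformer processes it.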

Training and Fine-tuning

The training process of HMT is divided into two stages to enhance efficiency:

  1. Initial Training: The model is trained to handle a few unrolled segments without memory recall.
  2. Extended Training: The pre-trained model is then extended with the memory recall mechanism and trained with a larger number of segments.

This multi-stage strategy allows HMT to train faster and achieve better performance on long-context tasks compared to single-stage training.
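A pseudo-schedule for the two stages might look like the sketch below; the segment counts and the `enable_memory_recall` / `fit` interfaces are hypothetical placeholders used only to make the schedule concrete.

```python
def train_hmt_two_stage(hmt_model, train_loader, trainer,
                        stage1_segments=2, stage2_segments=15):
    """Two-stage training sketch; segment counts are placeholder values."""
    # Stage 1: backpropagate through only a few unrolled segments with memory
    # recall disabled, so the model first learns to summarize segments.
    hmt_model.enable_memory_recall(False)               # hypothetical toggle
    trainer.fit(hmt_model, train_loader, bptt_segments=stage1_segments)

    # Stage 2: switch on memory recall and unroll more segments so the model
    # learns to retrieve relevant history from the long-term cache.
    hmt_model.enable_memory_recall(True)
    trainer.fit(hmt_model, train_loader, bptt_segments=stage2_segments)
```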

Experimental Results

HMT was tested using various datasets and transformer models to validate its effectiveness:

  • General Language Modeling: In tests with models such as OPT 2.7B and OpenLlamaV2 3B on Wikitext-103 and PG-19, HMT showed significant improvements. For OPT 2.7B, for example, HMT achieved a 25.5% decrease in perplexity on Wikitext-103, indicating much better language modeling performance over long contexts.
  • Question-Answering Tasks: With the PubMedQA dataset, HMT not only improved long-answer contextual reasoning by 9.81%, but also increased short-answer prediction accuracy by 1.0%.

Practical Implications

HMT offers several practical benefits:

  • Model Independence: HMT can be applied to any pre-trained model without altering the core architecture, making it a versatile enhancement for various transformer-based models (a wrapper sketch follows this list).
  • Efficiency in Handling Long Contexts: By managing long contexts with only 0.5% to 2% additional parameters, HMT is suitable for a wide range of applications, from book summarization to legal document processing.
  • Scalability: HMT can be scaled to even larger models and longer contexts with efficient GPU memory management techniques.
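As a rough illustration of the plug-in idea and the parameter-overhead claim, the wrapper below composes an unmodified backbone with one small cross-attention recall module; the class name and module choice are assumptions, not the API of the HMT-pytorch repository.

```python
import torch.nn as nn

class HMTShell(nn.Module):
    """Hypothetical plug-in wrapper: the backbone stays unmodified; only a
    small recall module adds trainable parameters."""

    def __init__(self, backbone, hidden_size, num_heads=8):
        super().__init__()
        self.backbone = backbone  # any pre-trained causal LM, used as-is
        self.recall = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def extra_parameter_fraction(self):
        extra = sum(p.numel() for p in self.recall.parameters())
        total = sum(p.numel() for p in self.backbone.parameters()) + extra
        return extra / total  # small relative to a multi-billion-parameter backbone
```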

Speculations on Future Development

HMT opens the door for further innovations in memory-augmented neural networks:

  • Integrated Memory Hierarchies: Future developments could explore even more sophisticated memory hierarchies or adaptive memory management systems.
  • Enhancing Retrieval-Augmented Models: Combining HMT with other retrieval-augmented techniques may yield even more powerful models for long-context understanding and generation tasks.
  • Edge Device Deployment: Optimizations for deploying HMT on edge devices could unlock its potential for real-time applications in resource-constrained environments.

Conclusion

HMT represents a step forward in the handling of long contexts by language models, leveraging a memory system inspired by human cognition. It blends the strengths of recurrent models and transformers to robustly process long documents and text sequences, providing a valuable tool for a broad range of NLP applications.
