
Abstract

Transformer-based models have emerged as one of the most widely used architectures for natural language processing, natural language generation, and image generation. The size of state-of-the-art models has increased steadily, reaching billions of parameters. These huge models are memory hungry and incur significant inference latency even on cutting-edge AI accelerators, such as GPUs. Specifically, the time and memory complexity of the attention operation is quadratic in the total context length, i.e., prompt and output tokens. Thus, several optimizations, such as key-value tensor caching and FlashAttention computation, have been proposed to meet the low-latency demands of applications relying on such large models. However, these techniques do not cater to the computationally distinct nature of the different phases of inference. To that end, we propose LeanAttention, a scalable technique for computing self-attention in the token-generation phase (decode phase) of decoder-only transformer models. LeanAttention enables scaling the attention mechanism to the challenging case of long context lengths by re-designing the execution flow of the decode phase. We identify that the associative property of online softmax can be treated as a reduction operation, allowing us to parallelize the attention computation over these large context lengths. We extend the "stream-K" style reduction of tiled calculation to self-attention to enable parallel computation, resulting in an average 2.6x speedup in attention execution over FlashAttention-2 and up to 8.33x speedup for 512k context lengths.

Figure: LeanAttention speedup over state-of-the-art attention across different context lengths, batch sizes, and attention heads.

Overview

  • LeanAttention introduces an optimized attention mechanism for the decode phase of transformer models, addressing the challenge of handling extensive context lengths efficiently.

  • By exploiting the associative property of softmax re-scaling and decomposing the work Stream-K style over units called LeanTiles, LeanAttention achieves significant performance gains, including up to 8.33x speedup for long contexts.

  • The technique has practical implications for reducing latency and enhancing scalability in real-time applications, and it opens up new directions for further theoretical research and development in AI.

LeanAttention: Speeding Up Attention Mechanisms for Transformer Models

Background

In the realm of AI and NLP, transformer-based models have become extremely valuable due to their performance on tasks such as text generation, machine translation, and sentiment analysis. These transformers, with self-attention mechanisms at their heart, require significant memory and computation power, particularly as contexts lengthen and models scale up to billions of parameters.

The Challenge

The real bottleneck in transformer models is the attention mechanism, especially when dealing with lengthy contexts: its execution time and memory usage grow quadratically with sequence length. To mitigate this problem, approaches like FlashAttention and FlashAttention-2 optimize memory movement and computational efficiency. However, these techniques do not account for the distinct computational demands of the two main phases of transformer inference: the prefill phase (where the model processes the input prompt) and the decode phase (where the model generates tokens sequentially).
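
To make the phase distinction concrete, here is a minimal NumPy sketch (purely illustrative, not the paper's implementation) of single-head attention in the two phases: prefill scores every prompt token against every other token, so the score matrix grows quadratically with context length, while each decode step scores one new query row against the cached keys and values.

```python
import numpy as np

def naive_attention(q, k, v):
    """Scaled dot-product attention for one head (no masking, for illustration)."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                      # (num_queries, context_len)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                                  # (num_queries, head_dim)

rng = np.random.default_rng(0)
context_len, head_dim = 2048, 64
k_cache = rng.standard_normal((context_len, head_dim))
v_cache = rng.standard_normal((context_len, head_dim))

# Prefill: every prompt token attends to the others -> an O(L^2) score matrix.
q_prefill = rng.standard_normal((context_len, head_dim))
out_prefill = naive_attention(q_prefill, k_cache, v_cache)   # (2048, 64)

# Decode: one new query attends to the whole KV cache -> a single long score row.
q_decode = rng.standard_normal((1, head_dim))
out_decode = naive_attention(q_decode, k_cache, v_cache)     # (1, 64)
```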

Introducing LeanAttention

LeanAttention is proposed as an innovative technique to optimize the attention mechanism during the decode phase of decoder-only transformer models. The decode phase is notoriously challenging due to its sequential nature and the necessity to handle extensive context lengths efficiently. LeanAttention exploits the associative properties of the attention mechanism to achieve significant performance gains.

Critical Concepts Behind LeanAttention

1. Softmax Re-scaling as a Reduction Operation

LeanAttention reconceptualizes the softmax operation involved in attention as a form of reduction. By treating re-scaling of un-scaled attention output tensors as an associative reduction operation, LeanAttention enables parallel computation over large context lengths. This is crucial for handling decode-phase workloads effectively.
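
The idea fits in a few lines. In the NumPy sketch below (illustrative only; the names partial_attention and merge are ours, not the paper's), un-normalized attention is computed over disjoint chunks of the KV cache, with each chunk carrying its running softmax max and sum of exponentials. The merge step re-scales and adds the partial results; because it is associative, the chunks can be combined in any order, and therefore in parallel.

```python
import numpy as np

def partial_attention(q, k_chunk, v_chunk):
    """Attend a query block to one KV chunk; return the un-normalized output
    plus the softmax statistics (running max m, running sum of exponentials l)."""
    d = q.shape[-1]
    s = q @ k_chunk.T / np.sqrt(d)                  # (nq, chunk_len)
    m = s.max(axis=-1, keepdims=True)               # running max
    p = np.exp(s - m)
    l = p.sum(axis=-1, keepdims=True)               # running sum of exp
    return p @ v_chunk, m, l                        # un-normalized output, stats

def merge(a, b):
    """Associative softmax re-scaling reduction of two partial results."""
    o_a, m_a, l_a = a
    o_b, m_b, l_b = b
    m = np.maximum(m_a, m_b)
    alpha, beta = np.exp(m_a - m), np.exp(m_b - m)  # re-scaling factors
    return o_a * alpha + o_b * beta, m, l_a * alpha + l_b * beta

rng = np.random.default_rng(1)
nq, L, d = 1, 1024, 64
q = rng.standard_normal((nq, d))
k = rng.standard_normal((L, d))
v = rng.standard_normal((L, d))

# Split the context into chunks that could be reduced in any order, i.e., in parallel.
chunks = [partial_attention(q, k[i:i + 256], v[i:i + 256]) for i in range(0, L, 256)]
o, m, l = chunks[0]
for c in chunks[1:]:
    o, m, l = merge((o, m, l), c)
out = o / l

# Reference: full softmax attention over the whole context.
s = q @ k.T / np.sqrt(d)
w = np.exp(s - s.max(axis=-1, keepdims=True))
ref = (w / w.sum(axis=-1, keepdims=True)) @ v
assert np.allclose(out, ref, atol=1e-6)
```

The final division by the accumulated sum recovers exactly the result of softmax attention over the full context, which the assertion checks.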

2. Stream-K Style Decomposition

Taking inspiration from optimized matrix-multiplication strategies in GPU computing, LeanAttention divides the attention computation into minimal units of work called LeanTiles. It then distributes these LeanTiles across the available processing units so that workloads stay balanced and hardware utilization is maximized. Unlike previous methods, LeanAttention maintains near-100% GPU occupancy regardless of the problem size.
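
As a toy picture of this assignment (our own simplified assumptions: a fixed tile size, contiguous tile ranges per compute unit, and plain Python rather than a CUDA kernel), the snippet below flattens every head's context tiles into one global list and splits that list as evenly as possible across a fixed number of compute units, independent of how many heads there are or how long the context is.

```python
# Toy Stream-K-style partitioning sketch (not the actual GPU kernel).
num_heads = 3          # illustrative: batch_size * num_heads attention outputs
context_len = 4096
tile_len = 256         # lean-tile granularity along the context dimension
num_units = 8          # available compute units (e.g., streaming multiprocessors)

tiles_per_head = context_len // tile_len
total_tiles = num_heads * tiles_per_head

for unit in range(num_units):
    # Each unit takes a contiguous, near-equal slice of the global tile range.
    start = unit * total_tiles // num_units
    stop = (unit + 1) * total_tiles // num_units
    tiles = [(t // tiles_per_head, t % tiles_per_head) for t in range(start, stop)]
    heads = sorted({h for h, _ in tiles})
    print(f"unit {unit}: tiles {start}-{stop - 1}, heads {heads}")
```

Heads whose tiles land on more than one unit need a final cross-unit reduction of their partial outputs, which is exactly where the associative softmax re-scaling described above comes in.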

Performance Gains

LeanAttention has demonstrated impressive speedups in the attention execution process. In benchmark tests, LeanAttention achieved:

  • An average speedup of 2.6x over FlashAttention-2.
  • Up to 8.33x speedup for very long context lengths (512k tokens).

Additionally, in multi-GPU environments, it showed even greater gains, confirming its scalability and efficiency for large-scale AI models.

Practical and Theoretical Implications

Practical Implications

LeanAttention's capability to handle lengthy contexts efficiently means that transformer-based models can now support richer, more coherent interactions. This improvement is particularly beneficial for applications requiring long contextual understanding, such as document search and retrieval, dialogue systems, and large-scale content generation.

  • Reduced Latency: By cutting down the execution time, LeanAttention facilitates faster responses in real-time applications.
  • Enhanced Scalability: It supports the ongoing trend of increasing model sizes and context lengths without compromising performance.

Theoretical Implications

From a theoretical perspective, LeanAttention contributes to the understanding of how matrix decomposition and associative properties in computations can be leveraged to optimize complex machine learning operations. It suggests directions for further research, such as exploring similar optimizations for other phases of model inference or extending these principles to other types of models.

What's Next?

LeanAttention opens up numerous avenues for future research and development in AI:

  • Integrating with Larger Models: Applying LeanAttention to even larger transformers and comparing its performance across different architectures.
  • Extending to Other Phases: Investigating how the associative property and Stream-K decomposition can optimize other inference phases or even training processes.
  • Multi-GPU and Distributed Systems: Further optimizing LeanAttention for more complex hardware setups, enabling seamless scalability across distributed systems.

Conclusion

LeanAttention presents a significant advance in the efficient execution of the attention mechanism within transformer-based models, particularly during the decode phase. By rethinking the softmax operation and utilizing a novel decomposition strategy, LeanAttention offers substantial performance improvements and provides a robust framework for scaling up transformer models in the future. Whether you're working with extensive text generation tasks or building models that require deep contextual understanding, LeanAttention is a step forward in making large-scale NLP more efficient and scalable.
