
Abstract

Huge memory consumption has been a major bottleneck for deploying high-throughput LLMs in real-world applications. In addition to the large number of parameters, the key-value (KV) cache for the attention mechanism in the transformer architecture consumes a significant amount of memory, especially when the number of layers is large for deep language models. In this paper, we propose a novel method that only computes and caches the KVs of a small number of layers, thus significantly saving memory consumption and improving inference throughput. Our experiments on LLMs show that our method achieves up to 26$\times$ higher throughput than standard transformers and competitive performance in language modeling and downstream tasks. In addition, our method is orthogonal to existing transformer memory-saving techniques, so it is straightforward to integrate them with our model, achieving further improvement in inference efficiency. Our code is available at https://github.com/whyNLP/LCKV.

Figure: Latency per token and memory consumption of StreamingLLM vs. the integrated model with varying cache sizes.

Overview

  • The paper proposes a novel method to reduce memory consumption for LLMs during real-time inference by focusing on the key-value (KV) cache in transformer architectures.

  • The technique involves computing and storing KVs only for the final layer, leading to significant memory savings and higher throughput without substantial performance degradation.

  • Experimental results show that this approach allows for larger batch sizes and faster processing with minimal impact on model accuracy, making it highly beneficial for practical applications.

Efficient Memory Reduction for LLMs with Layer-Condensed KV Cache

Introduction

Deploying LLMs for real-time applications often runs into a major bottleneck: memory consumption. A significant share of this memory goes toward storing the key-value (KV) cache for the attention layers in transformer architectures. To put this into perspective, the KV cache can take up more than 30% of GPU memory during model deployment.

Now, what if you could achieve the same performance with a fraction of the memory load? This paper proposes a novel approach to do just that by reducing the number of layers that compute and cache their KVs, focusing primarily on the final layer. This method can save over 86% of the KV cache memory, leading to up to 26 times higher throughput with minimal performance degradation.
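As a rough, back-of-the-envelope illustration (not taken from the paper), the sketch below estimates the KV cache size for an assumed Llama-7B-like configuration and how much would remain if only a couple of layers' KVs were cached. All configuration numbers and the number of retained layers are assumptions; the exact savings depend on how many "warmup" layers are kept.

```python
# Back-of-the-envelope KV cache sizing (illustrative numbers, not from the paper).
# Per token and per layer, the cache stores a key and a value vector of size hidden_size.

n_layers = 32          # assumed, Llama-7B-like
hidden_size = 4096     # assumed (n_heads * head_dim)
bytes_per_value = 2    # fp16
batch_size = 32        # assumed serving batch
seq_len = 2048         # assumed context length

per_token_bytes = 2 * n_layers * hidden_size * bytes_per_value   # K and V, all layers
full_cache_gib = per_token_bytes * seq_len * batch_size / 2**30

cached_layers = 2      # e.g. the top layer plus one "warmup" layer (assumption)
condensed_gib = full_cache_gib * cached_layers / n_layers

print(f"standard KV cache:  {full_cache_gib:.1f} GiB")
print(f"condensed KV cache: {condensed_gib:.1f} GiB "
      f"({100 * (1 - cached_layers / n_layers):.0f}% saved)")
```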

The Layer-Condensed KV Cache Idea

Model

In a typical transformer, each layer of the model computes and stores its KVs during inference. This new method proposes that the queries of all layers are paired only with the KVs of the top layer. By doing so:

  1. Memory Reduction: Only the top layer's KVs need to be calculated and cached, significantly saving memory.
  2. Computation Savings: The KV projections of all other layers no longer need to be computed, reducing the computational burden.

One challenge is that each token normally attends to itself, yet its own top-layer KVs are not yet available when its lower layers are being computed. This cyclic dependency is resolved by masking the diagonal of the attention matrix, so a token does not attend to itself.
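A minimal PyTorch-style sketch of what a decode step could look like under this scheme (an illustrative simplification, not the authors' implementation; the projection matrices, single-head attention, and the absence of normalization and positional encoding are all simplifying assumptions):

```python
import torch
import torch.nn.functional as F

def decode_step(h, layers, top_kv_cache):
    """One decoding step for a new token: queries at *every* layer attend only
    to the cached top-layer KVs of previous tokens. `layers` is a list of dicts
    of projection matrices (illustrative, single-head, no norms or positions).
    Assumes a non-empty cache (past_len >= 1)."""
    k_cache, v_cache = top_kv_cache                          # each: (past_len, d)
    for layer in layers:
        q = h @ layer["Wq"]                                  # (1, d)
        scores = q @ k_cache.T / k_cache.shape[-1] ** 0.5    # (1, past_len)
        h = h + F.softmax(scores, dim=-1) @ v_cache @ layer["Wo"]   # attention + residual
        h = h + torch.relu(h @ layer["W1"]) @ layer["W2"]            # feed-forward + residual
    # Only the top layer's KVs are computed and appended to the cache.
    new_k, new_v = h @ layers[-1]["Wk"], h @ layers[-1]["Wv"]
    return h, (torch.cat([k_cache, new_k]), torch.cat([v_cache, new_v]))
```

In this sketch the diagonal masking shows up implicitly: the new token attends only to the cached KVs of earlier tokens, never to its own not-yet-computed top-layer KVs.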

Training

Training this model introduces a bit of complexity as it requires an iterative computation process. Here's a breakdown:

  1. From Sequential to Parallel Training: Each token's computation depends on the top-layer KVs of the preceding tokens, which would naively force sequential, token-by-token training. Instead, the whole sequence is processed in parallel with $n$ iterations of bottom-up transformer computation, each iteration reusing the top-layer KVs produced by the previous one.
  2. Gradient Stopping: To keep the memory cost of backpropagation manageable, gradients are stopped for the early iterations and the loss is backpropagated only through the last few (see the sketch after this list).
  3. Fast Convergence: The KVs are observed to converge quickly across iterations, so a small number of iterations already gives a good approximation.
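A hedged sketch of what this iterative training scheme might look like (simplified; the `model(tokens, top_k, top_v)` interface and the iteration counts are stand-ins, not the authors' code):

```python
import torch

def iterative_forward(model, tokens, n_iters=7, backprop_iters=2):
    """Run n_iters bottom-up passes over the whole sequence in parallel, each
    pass attending to the top-layer KVs produced by the previous pass.
    Gradients flow only through the last `backprop_iters` passes."""
    # Stand-in interface: the model's attention at every layer reads the given
    # top-layer KVs, and the call returns logits plus refreshed top-layer KVs.
    top_k = top_v = None   # first pass: no previous-iteration KVs available

    for i in range(n_iters):
        if i < n_iters - backprop_iters:
            with torch.no_grad():          # early passes only refine the KV estimates
                logits, top_k, top_v = model(tokens, top_k, top_v)
        else:
            # The KVs entering the first backpropagated pass carry no gradient
            # history (they came from no_grad passes), so the gradient stops there.
            logits, top_k, top_v = model(tokens, top_k, top_v)
    return logits
```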

Evaluation and Results

Generation Throughput

Experiments were conducted on different setups, including models with 1.1B, 7B, and 30B parameters using RTX 3090 (24GB) and A100 (80GB) GPUs. The results were notable:

  • Batch Sizes: The proposed method achieved up to 32 times larger batch sizes.
  • Throughput: A significant boost in throughput, with some configurations showing up to 26 times the improvement in tokens per second.

For instance, a 7B model on an A100 GPU achieved up to 421.02 tokens/s compared to the standard Llama’s 141.10 tokens/s.

Model Performance

Pre-training and evaluating a 1.1B model with this method showed performance competitive with a standard transformer on language modeling and downstream tasks.

Integration with StreamingLLM

This method is also flexible enough to integrate with other memory-saving techniques like StreamingLLM. Such integration achieved lower latency and memory consumption while maintaining stable performance on long sequences up to four million tokens.
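For a sense of what such an integration involves, here is a minimal sketch of a StreamingLLM-style eviction policy applied to the single condensed KV cache; the function and parameter names are illustrative, and details such as re-assigning positions within the cache are omitted:

```python
import torch

def evict(k_cache, v_cache, n_sink=4, window=1024):
    """Keep the first `n_sink` "attention sink" tokens plus the most recent
    `window` tokens of the condensed (top-layer) KV cache; drop the rest.
    (StreamingLLM also re-assigns positions within the cache; omitted here.)"""
    if k_cache.shape[0] <= n_sink + window:
        return k_cache, v_cache
    k_kept = torch.cat([k_cache[:n_sink], k_cache[-window:]])
    v_kept = torch.cat([v_cache[:n_sink], v_cache[-window:]])
    return k_kept, v_kept
```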

Implications

Practical Implications

  • Deployment: Useful for deploying LLMs in memory-constrained environments, achieving real-time performance.
  • Cost Efficiency: Reducing the need for high-memory GPUs can lead to cost savings in cloud deployments and broader accessibility.

Theoretical Implications

  • Transformer Design: This approach offers a fresh perspective on transformer design, focusing on the utility of token representations primarily from the top layer.
  • Further Research: Opens up avenues for combining multiple memory-saving techniques for even greater efficiency.

Future Directions

The promising results suggest multiple future research directions:

  • Efficient Training Approaches: Developing even more efficient training methods to reduce training time.
  • New Kernels: Designing computational kernels that can handle larger batch sizes effectively.
  • Larger Models: Testing this approach on even larger and more complex LLMs.

Conclusion

In summary, the proposed method significantly improves the inference efficiency of LLMs by reducing KV cache memory consumption through a clever reduction in the number of cached layers. The results show that it’s possible to achieve high throughputs with minimal performance trade-offs, making it a strong candidate for practical, real-world LLM deployments.
