
Reducing Transformer Key-Value Cache Size with Cross-Layer Attention

(arXiv:2405.12981)
Published May 21, 2024 in cs.LG and cs.CL

Abstract

Key-value (KV) caching plays an essential role in accelerating decoding for transformer-based autoregressive LLMs. However, the amount of memory required to store the KV cache can become prohibitive at long sequence lengths and large batch sizes. Since the invention of the transformer, two of the most effective interventions discovered for reducing the size of the KV cache have been Multi-Query Attention (MQA) and its generalization, Grouped-Query Attention (GQA). MQA and GQA both modify the design of the attention block so that multiple query heads can share a single key/value head, reducing the number of distinct key/value heads by a large factor while only minimally degrading accuracy. In this paper, we show that it is possible to take Multi-Query Attention a step further by also sharing key and value heads between adjacent layers, yielding a new attention design we call Cross-Layer Attention (CLA). With CLA, we find that it is possible to reduce the size of the KV cache by another 2x while maintaining nearly the same accuracy as unmodified MQA. In experiments training 1B- and 3B-parameter models from scratch, we demonstrate that CLA provides a Pareto improvement over the memory/accuracy tradeoffs which are possible with traditional MQA, enabling inference with longer sequence lengths and larger batch sizes than would otherwise be possible.

Figure: KV cache structures in a 10-layer transformer under different Cross-Layer Attention configurations with varying sharing factors.

Overview

  • Cross-Layer Attention (CLA) reduces the memory footprint of LLMs by sharing key-value (KV) heads across adjacent layers, allowing for significant memory savings with minimal accuracy loss.

  • CLA reduces the KV cache size by a further 2x on top of existing Multi-Query Attention (MQA) techniques, enabling inference with longer sequences and larger batch sizes than would otherwise fit in memory.

  • Experimental results from 1B- and 3B-parameter models demonstrate that CLA maintains performance while reducing memory requirements, making it a potent tool for optimizing LLMs, especially when combined with MQA.

Introducing Cross-Layer Attention (CLA) for Transformer-Based LLMs

Memory Footprint and Key-Value (KV) Caching

One of the key challenges when serving LLMs is managing the memory footprint of the Key-Value (KV) cache, which stores the keys and values of previously processed tokens so they do not have to be recomputed at every decoding step. Its size scales with sequence length, batch size, the number of layers, and the number of key/value heads, so at long contexts or large batch sizes the cache itself can become the bottleneck that limits what can be served without offloading it to slower memory.
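As a back-of-the-envelope illustration (a hedged sketch; the formula layout and the model dimensions below are assumptions for illustration, not numbers from the paper), the KV cache stores two tensors per layer, and its size grows linearly in layers, key/value heads, sequence length, and batch size:

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim,
                   seq_len, batch_size, bytes_per_elem=2):
    """Two cached tensors (K and V) per layer, each of shape
    [batch_size, seq_len, num_kv_heads, head_dim]."""
    return (2 * num_layers * num_kv_heads * head_dim
            * seq_len * batch_size * bytes_per_elem)

# Hypothetical 7B-class dimensions (32 layers, 32 KV heads, head_dim 128),
# fp16 cache, 8 sequences of 8K tokens:
size = kv_cache_bytes(num_layers=32, num_kv_heads=32, head_dim=128,
                      seq_len=8192, batch_size=8)
print(f"{size / 2**30:.0f} GiB")  # 32 GiB -- larger than the fp16 weights of such a model
```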

Multi-Query and Grouped-Query Attention: A Quick Recap

Before diving into Cross-Layer Attention (CLA), it’s important to understand Multi-Query Attention (MQA) and Grouped-Query Attention (GQA). These techniques shrink the KV cache by letting multiple query heads share a single key/value head, significantly cutting memory requirements with little loss in accuracy. In simpler terms (see the sketch after this list):

  • MQA: All query heads share a single key/value head.
  • GQA: Query heads are divided into groups, and each group shares one key/value head.
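For concreteness, here is a minimal PyTorch-style sketch (illustrative only; the class and parameter names are my own, not the paper's code) showing that MHA, GQA, and MQA differ only in how many key/value heads are projected, and therefore in how much per-token state must be cached:

```python
import torch.nn as nn

class AttentionProjections(nn.Module):
    """Q/K/V projections. n_kv_heads controls the KV cache size:
      n_kv_heads == n_heads     -> standard multi-head attention (MHA)
      1 < n_kv_heads < n_heads  -> grouped-query attention (GQA)
      n_kv_heads == 1           -> multi-query attention (MQA)
    """
    def __init__(self, d_model, n_heads, n_kv_heads, head_dim):
        super().__init__()
        self.q_proj = nn.Linear(d_model, n_heads * head_dim)
        self.k_proj = nn.Linear(d_model, n_kv_heads * head_dim)  # cached per token
        self.v_proj = nn.Linear(d_model, n_kv_heads * head_dim)  # cached per token
```

Each group of n_heads // n_kv_heads query heads attends over the same cached key/value head, so shrinking n_kv_heads shrinks the cache by the same factor.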

Cross-Layer Attention (CLA): The New Kid on the Block

The paper proposes taking key/value sharing a step further with Cross-Layer Attention (CLA). In essence, CLA shares KV heads not just among query heads within the same layer (as MQA and GQA do) but also across adjacent layers, so several consecutive layers attend over the same cached keys and values. The goal is to reduce the memory requirement further while maintaining accuracy.

Key Findings:

  1. Memory Efficiency: CLA can cut the KV cache size by another 2x beyond what MQA achieves.
  2. Minimal Accuracy Loss: Accuracy stays nearly on par with unmodified MQA.
  3. Experimental Validation: Pretraining experiments with 1B- and 3B-parameter models demonstrated that CLA achieves superior memory/accuracy trade-offs.

Practical Implications and Takeaways

For practical applications, integrating CLA into an LLM could mean the following (a back-of-the-envelope example follows the list):

  • Increased Sequence Length: You could handle longer sequences without incurring a massive memory overhead.
  • Larger Batch Sizes: More efficient memory usage allows for larger batch sizes during inference.
  • General Guidance: Combining CLA with MQA is recommended for the best memory/accuracy trade-off, especially with a sharing factor of 2 (CLA2).
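To make the batch-size point concrete, here is a hedged back-of-the-envelope calculation (the memory budget and model dimensions are assumptions, not figures from the paper): halving the per-sequence KV footprint roughly doubles the number of sequences that fit under a fixed cache budget.

```python
def max_batch_size(budget_bytes, num_layers, n_kv_heads, head_dim,
                   seq_len, bytes_per_elem=2, kv_layer_fraction=1.0):
    """kv_layer_fraction is the fraction of layers that store their own K/V
    (1.0 for plain MQA, 0.5 for MQA combined with CLA2)."""
    per_seq = (2 * num_layers * kv_layer_fraction * n_kv_heads
               * head_dim * seq_len * bytes_per_elem)
    return int(budget_bytes // per_seq)

budget = 16 * 2**30  # assume 16 GiB reserved for the KV cache
dims = dict(num_layers=32, n_kv_heads=1, head_dim=128, seq_len=8192)  # MQA-style
print(max_batch_size(budget, **dims))                         # 128 sequences
print(max_batch_size(budget, kv_layer_fraction=0.5, **dims))  # 256 with CLA2
```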

How Cross-Layer Attention Works

Here’s a simplified breakdown of how CLA works:

  • Traditional transformers compute unique key/value pairs for each layer.
  • In CLA, some layers compute fresh key/value pairs, and adjacent layers reuse these pairs, reducing the total memory footprint.

You can visualize the configurations like this (a code sketch follows the list):

  • CLA2: Every two adjacent layers share the same KV cache.
  • CLA3: Every three adjacent layers share the same KV cache, and so on.
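Below is a minimal, hedged PyTorch sketch of CLA2 combined with MQA (illustrative only: the module names, bare residual structure, and single shared KV head are simplifying assumptions, not the paper's implementation). The first layer of each pair projects keys and values; the second layer reuses those same tensors, so only half of the layers contribute to the KV cache.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CLA2Pair(nn.Module):
    """Two attention layers sharing one key/value projection (CLA2 + MQA).
    Only layer 0 of the pair computes K/V, so only its K/V need caching."""
    def __init__(self, d_model, n_heads, head_dim):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, head_dim
        self.q0 = nn.Linear(d_model, n_heads * head_dim)
        self.q1 = nn.Linear(d_model, n_heads * head_dim)
        self.kv = nn.Linear(d_model, 2 * head_dim)   # one shared KV head (MQA)
        self.o0 = nn.Linear(n_heads * head_dim, d_model)
        self.o1 = nn.Linear(n_heads * head_dim, d_model)

    def _attend(self, q_proj, o_proj, x, k, v):
        B, T, _ = x.shape
        q = q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        # Broadcast the single cached KV head across all query heads.
        k = k.expand(-1, self.n_heads, -1, -1)
        v = v.expand(-1, self.n_heads, -1, -1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return o_proj(out.transpose(1, 2).reshape(B, T, -1))

    def forward(self, x):
        B, T, _ = x.shape
        k, v = self.kv(x).chunk(2, dim=-1)           # computed and cached once
        k = k.view(B, T, 1, self.head_dim).transpose(1, 2)
        v = v.view(B, T, 1, self.head_dim).transpose(1, 2)
        x = x + self._attend(self.q0, self.o0, x, k, v)  # layer 0: fresh K/V
        x = x + self._attend(self.q1, self.o1, x, k, v)  # layer 1: reuses layer 0's K/V
        return x

block = CLA2Pair(d_model=512, n_heads=8, head_dim=64)
y = block(torch.randn(2, 16, 512))                   # [batch, seq, d_model]
```

A full decoder would interleave MLP blocks and layer norms and append K/V to a cache at each decoding step; the point of the sketch is only that the second layer issues no K/V projection of its own.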

Extensive Experiments and Robust Results

The researchers put CLA through its paces via various pretraining experiments. Key highlights include:

1B-Parameter Scale:

  • Design Space Exploration: Training diverse CLA and non-CLA models to map out the accuracy/memory trade-offs.
  • Learning Rate Tuning: Ensuring that their results hold even with optimized learning rates for the compared models.

Results showed that:

  • MQA combined with CLA2 (MQA-CLA2) achieved better (i.e., lower) perplexity, a standard measure of language modeling quality, than baseline MQA models using the same amount of KV cache memory.
  • CLA2 is the most effective configuration, outperforming larger sharing factors like CLA3 or CLA4.

3B-Parameter Scale:

  • Similar experiments confirmed that the beneficial effects of CLA observed at the 1B scale hold true even at the larger 3B scale.

What’s Next? Future Directions

The potential for future work with CLA includes:

  • Longer Contexts: Evaluating CLA’s performance on models designed to handle longer sequences efficiently.
  • Incorporation with Other Techniques: Combining CLA with other memory-efficient mechanisms, such as those reducing the bandwidth or time complexity of the attention mechanism.
  • Complete System Integration: Testing CLA in a full inference system to quantify end-to-end cost reductions and efficiency improvements.

Related Work and Context

The work on CLA fits into a broader landscape of techniques aimed at improving the memory efficiency of transformers. Other efforts include the query-head sharing of MQA and GQA themselves, methods that compress or quantize the KV cache, and architectural alternatives that reduce the cost of attention altogether. CLA is complementary to many of these approaches and, as the paper demonstrates, combines particularly well with MQA.

Conclusion

Cross-Layer Attention (CLA) represents a significant step in optimizing the memory footprint of transformer-based LLMs. By sharing KV heads across layers, CLA achieves a notable reduction in memory usage with minimal accuracy trade-offs, proving to be a valuable tool for scaling models to work with longer sequences and larger batch sizes. Practitioners looking to optimize their model’s memory efficiency should definitely consider integrating CLA, particularly alongside MQA for the best results.

Feel free to dig deeper into the paper to explore how CLA could benefit your specific LLM applications!
