
The Mamba in the Llama: Distilling and Accelerating Hybrid Models (2408.15237v4)

Published 27 Aug 2024 in cs.LG and cs.AI

Abstract: Linear RNN architectures, like Mamba, can be competitive with Transformer models in language modeling while having advantageous deployment characteristics. Given the focus on training large-scale Transformer models, we consider the challenge of converting these pretrained models for deployment. We demonstrate that it is feasible to distill large Transformers into linear RNNs by reusing the linear projection weights from attention layers with academic GPU resources. The resulting hybrid model, which incorporates a quarter of the attention layers, achieves performance comparable to the original Transformer in chat benchmarks and outperforms open-source hybrid Mamba models trained from scratch with trillions of tokens in both chat benchmarks and general benchmarks. Moreover, we introduce a hardware-aware speculative decoding algorithm that accelerates the inference speed of Mamba and hybrid models. Overall we show how, with limited computation resources, we can remove many of the original attention layers and generate from the resulting model more efficiently. Our top-performing model, distilled from Llama3-8B-Instruct, achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and 7.35 on MT-Bench, surpassing the best 8B scale instruction-tuned linear RNN model. We also find that the distilled model has natural length extrapolation, showing almost perfect accuracy in the needle-in-a-haystack test at 20x the distillation length. Code and pre-trained checkpoints are open-sourced at https://github.com/jxiw/MambaInLlama and https://github.com/itsdaniele/speculative_mamba.


Summary

  • The paper demonstrates a novel distillation method that converts large Transformers into linear RNNs by reusing attention layer projections.
  • It introduces a hybrid model that integrates selective attention layers, achieving near-original performance while outperforming other open-source models.
  • A hardware-aware speculative decoding algorithm is employed to accelerate inference, yielding competitive benchmarks against state-of-the-art instruction-tuned models.

Distilling and Accelerating Hybrid Models: A Comprehensive Evaluation

The paper "The Mamba in the Llama: Distilling and Accelerating Hybrid Models" by Junxiong Wang et al. presents a detailed paper on distilling large-scale Transformer models into linear RNN models, specifically using the Mamba architecture, and enhancing their inference efficiency via speculative decoding. This research is positioned at the convergence of two key challenges in the deployment of LLMs: reducing computational overhead without sacrificing model performance and accelerating inference for practical applications.

Summary of Contributions

  1. Distillation using Linear RNNs: The authors convert pretrained large-scale Transformer models into linear RNNs. They demonstrate that linear RNN architectures, such as Mamba, can compete effectively with Transformers on language modeling tasks, and they address the technical challenge of distillation by reusing the linear projection weights from the attention layers of the original Transformer (a schematic of this weight reuse follows the list).
  2. Hybrid Model Performance: The resulting hybrid model retains a quarter of the original attention layers alongside the distilled Mamba layers. It achieves performance close to the original Transformer on chat benchmarks and outperforms open-source hybrid Mamba models trained from scratch on trillions of tokens.
  3. Improved Inference Speed: To address the quadratic complexity and large key-value (KV) cache requirements of Transformers, the authors introduce a hardware-aware speculative decoding algorithm. This method significantly accelerates inference speed for Mamba and hybrid models.
  4. Empirical Validation: The top-performing model, distilled from Llama3-8B-Instruct, demonstrates strong comparative performance. It achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and 7.35 on MT-Bench, surpassing the best 8B-scale instruction-tuned linear RNN model.
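
The weight reuse described in item 1 can be sketched concretely. The snippet below is a minimal, hypothetical PyTorch illustration of initializing a Mamba-style block from a pretrained attention block: it assumes Llama-style attention modules exposing `q_proj`, `k_proj`, `v_proj`, and `o_proj`, and a linear-RNN block with correspondingly named projections. The attribute names and module layout are assumptions, not the released implementation, though the Q-to-C, K-to-B, V-to-x mapping follows the paper's high-level recipe.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def init_linear_rnn_from_attention(attn: nn.Module, rnn_block: nn.Module) -> None:
    """Seed a Mamba-style linear RNN block with pretrained attention projections.

    Rough correspondence used here (illustrative):
        W_Q -> C projection (readout of the recurrent state)
        W_K -> B projection (write into the recurrent state)
        W_V -> x projection (token-wise input)
        W_O -> output projection
    Attribute names are hypothetical placeholders for this sketch.
    """
    rnn_block.C_proj.weight.copy_(attn.q_proj.weight)    # queries seed the state readout
    rnn_block.B_proj.weight.copy_(attn.k_proj.weight)    # keys seed the state input
    rnn_block.x_proj.weight.copy_(attn.v_proj.weight)    # values seed the token projection
    rnn_block.out_proj.weight.copy_(attn.o_proj.weight)  # output projection carried over
```

Only the projections are transferred; the Mamba-specific state-space parameters (such as the state transition and step size) are newly initialized and then learned during distillation.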

Theoretical Implications

The paper establishes that linearizing attention mechanisms into linear RNNs can retain much of the original model's generation quality. This finding is critical, as it suggests that with appropriate distillation techniques, the benefits of Transformer architectures can be transferred to more computationally efficient models. The authors use a modified Mamba architecture initialized from the attention blocks of a pretrained model, highlighting the natural relationship between multihead attention and linear RNN formulations.
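
The attention-to-linear-RNN connection the authors rely on can be made explicit with the standard linearized-attention identity. The notation below is generic (one head, scaling omitted) rather than the paper's exact Mamba parameterization, which additionally uses learned, input-dependent state dynamics.

```latex
% Causal softmax attention at step t (one head):
o_t = \sum_{s \le t} \operatorname{softmax}_s\!\left(q_t^{\top} k_s\right) v_s

% Dropping the softmax yields a quantity a linear RNN computes with a fixed-size state:
\tilde{o}_t = \sum_{s \le t} \left(q_t^{\top} k_s\right) v_s
            = H_t\, q_t,
\qquad H_t = H_{t-1} + v_t k_t^{\top}, \quad H_0 = 0.
```

Because the recurrent state H_t has fixed size, generation no longer requires a growing key-value cache, which is the deployment advantage the distilled models inherit.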

Practical Implications

Practically, the research presents a pathway to deploying LLMs on hardware with limited resources. The speculative decoding algorithm is particularly noteworthy. By pairing a fast draft model with the larger model as verifier, the algorithm achieves substantial speedups in token generation without materializing intermediate states, which keeps memory usage low, a crucial consideration for hardware deployments.
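
To make the draft-and-verify loop concrete, here is a minimal, generic speculative-decoding sketch in Python. It is not the paper's hardware-aware algorithm (which, among other things, avoids materializing intermediate SSM states); the `draft_model`/`verify_model` interfaces, greedy acceptance rule, and batch-size-1 simplification are all assumptions made for exposition.

```python
import torch

@torch.no_grad()
def speculative_decode_step(draft_model, verify_model, prefix_ids, k=4):
    """One draft-and-verify step (greedy variant, batch size 1, illustrative only).

    draft_model:  fast proposal model; called as model(ids) -> logits of shape [1, T, vocab].
    verify_model: large target model (e.g., the distilled hybrid) with the same interface.
    Returns prefix_ids extended by the accepted tokens plus one token from the verifier.
    """
    # 1. Draft: autoregressively propose k candidate tokens with the cheap model.
    draft_ids = prefix_ids
    for _ in range(k):
        next_id = draft_model(draft_ids)[:, -1, :].argmax(dim=-1, keepdim=True)
        draft_ids = torch.cat([draft_ids, next_id], dim=-1)

    # 2. Verify: score the whole drafted sequence with the target model in one forward pass.
    #    Position j of the output predicts the token at position j + 1.
    target_pred = verify_model(draft_ids).argmax(dim=-1)

    # 3. Accept the longest drafted prefix the target model agrees with.
    n = prefix_ids.shape[1]
    accepted = prefix_ids
    for i in range(k):
        proposed = draft_ids[:, n + i : n + i + 1]
        if torch.equal(target_pred[:, n + i - 1 : n + i], proposed):
            accepted = torch.cat([accepted, proposed], dim=-1)
        else:
            break

    # 4. Append one token from the verifier itself: its prediction at the first mismatch,
    #    or its continuation if every drafted token was accepted.
    correction = target_pred[:, accepted.shape[1] - 1 : accepted.shape[1]]
    return torch.cat([accepted, correction], dim=-1)
```

Each call costs one parallel pass of the large model plus k cheap draft passes and yields between one and k + 1 new tokens, which is where the speedup comes from when the draft model's proposals are frequently accepted.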

Future Directions

This paper opens several avenues for further research in optimizing LLM deployment:

  • Exploring Smaller Models: Future studies could investigate the efficacy of these distillation techniques on smaller-scale Transformer models to broaden their applicability.
  • Advanced Speculative Decoding Techniques: Developing even more efficient multi-step verification methods could further enhance inference speed.
  • Automated Distillation Processes: Research into automating the distillation process could make these techniques more accessible and widely used.

Conclusion

The work by Wang et al. provides compelling evidence that Transformer models’ prowess can be distilled into linear RNN architectures like Mamba without significant loss in performance. By combining this with efficient speculative decoding algorithms, they present a balanced solution to the computational demands of deploying LLMs. This paper stands as a testament to the ongoing evolution in deep learning methods, advocating for more efficient and practical AI systems.
