Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads (2401.10774v3)
Abstract: Large language models (LLMs) employ auto-regressive decoding that requires sequential computation, with each step dependent on the previous step's output. This creates a bottleneck, as each step requires moving the full model parameters from High-Bandwidth Memory (HBM) to the accelerator's cache. While methods such as speculative decoding have been proposed to address this issue, their implementation is impeded by the challenges of acquiring and maintaining a separate draft model. In this paper, we present Medusa, an efficient method that augments LLM inference by adding extra decoding heads to predict multiple subsequent tokens in parallel. Using a tree-based attention mechanism, Medusa constructs multiple candidate continuations and verifies them simultaneously in each decoding step. By leveraging parallel processing, Medusa substantially reduces the number of decoding steps required. We present two levels of fine-tuning procedures for Medusa to meet the needs of different use cases. Medusa-1: Medusa is directly fine-tuned on top of a frozen backbone LLM, enabling lossless inference acceleration. Medusa-2: Medusa is fine-tuned together with the backbone LLM, enabling better prediction accuracy of the Medusa heads and higher speedup, but requiring a special training recipe that preserves the backbone model's capabilities. Moreover, we propose several extensions that improve or expand the utility of Medusa, including a self-distillation procedure to handle situations where no training data is available and a typical acceptance scheme to boost the acceptance rate while maintaining generation quality. We evaluate Medusa on models of various sizes and training procedures. Our experiments demonstrate that Medusa-1 can achieve over 2.2x speedup without compromising generation quality, while Medusa-2 further improves the speedup to 2.3-3.6x.
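To make the core idea concrete, below is a minimal PyTorch sketch of Medusa-style extra decoding heads. It is an illustrative assumption rather than the authors' implementation: the class names (`ResidualBlock`, `MedusaHeads`), the single residual feed-forward block per head, and the hyperparameters (hidden size 4096, vocabulary size 32000, four heads) are chosen for the example. Each head maps the backbone's final hidden state to logits for a token at a further offset beyond the backbone's own next-token prediction; in the Medusa-1 setting the backbone stays frozen and only these heads are trained, and the candidates they propose are then verified with tree-based attention (not shown here).

```python
# Illustrative sketch (not the paper's reference code): extra decoding heads
# attached to a backbone LLM's final hidden state. Names and sizes are
# hypothetical choices for this example.
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """One feed-forward layer with a SiLU activation and a residual connection."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.act(self.linear(x))


class MedusaHeads(nn.Module):
    """K extra heads; each head speculates a token at a further offset beyond
    the backbone's own next-token prediction."""

    def __init__(self, hidden_size: int, vocab_size: int, num_heads: int = 4):
        super().__init__()
        self.heads = nn.ModuleList(
            [
                nn.Sequential(
                    ResidualBlock(hidden_size),
                    nn.Linear(hidden_size, vocab_size),
                )
                for _ in range(num_heads)
            ]
        )

    def forward(self, last_hidden: torch.Tensor) -> torch.Tensor:
        # last_hidden: (batch, hidden_size), taken from the backbone's last position.
        # Returns (num_heads, batch, vocab_size) logits, one set per speculated offset.
        return torch.stack([head(last_hidden) for head in self.heads])


# Usage sketch in the Medusa-1 setting: the backbone is frozen, only the heads train.
heads = MedusaHeads(hidden_size=4096, vocab_size=32000, num_heads=4)
dummy_hidden = torch.randn(2, 4096)        # stand-in for the backbone's output
speculative_logits = heads(dummy_hidden)   # proposals for the next 4 token offsets
print(speculative_logits.shape)            # torch.Size([4, 2, 32000])
```

In a full decoding step, the top candidates from each head would be expanded into a small tree of continuations and checked in a single verification pass by the backbone, so accepted prefixes advance the sequence by several tokens per step.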