Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads

(2401.10774)
Published Jan 19, 2024 in cs.LG and cs.CL

Abstract

The inference process in LLMs is often limited due to the absence of parallelism in the auto-regressive decoding process, resulting in most operations being restricted by the memory bandwidth of accelerators. While methods such as speculative decoding have been suggested to address this issue, their implementation is impeded by the challenges associated with acquiring and maintaining a separate draft model. In this paper, we present Medusa, an efficient method that augments LLM inference by adding extra decoding heads to predict multiple subsequent tokens in parallel. Using a tree-based attention mechanism, Medusa constructs multiple candidate continuations and verifies them simultaneously in each decoding step. By leveraging parallel processing, Medusa introduces only minimal overhead in terms of single-step latency while substantially reducing the number of decoding steps required. We present two levels of fine-tuning procedures for Medusa to meet the needs of different use cases: Medusa-1: Medusa is directly fine-tuned on top of a frozen backbone LLM, enabling lossless inference acceleration. Medusa-2: Medusa is fine-tuned together with the backbone LLM, enabling better prediction accuracy of Medusa heads and higher speedup but needing a special training recipe that preserves the backbone model's capabilities. Moreover, we propose several extensions that improve or expand the utility of Medusa, including a self-distillation to handle situations where no training data is available and a typical acceptance scheme to boost the acceptance rate while maintaining generation quality. We evaluate Medusa on models of various sizes and training procedures. Our experiments demonstrate that Medusa-1 can achieve over 2.2x speedup without compromising generation quality, while Medusa-2 further improves the speedup to 2.3-3.6x.
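To make the tree-based verification step concrete, the sketch below is an illustration under assumed data structures, not the authors' implementation: it builds the kind of boolean attention mask that lets each candidate token attend only to itself and its ancestors in the candidate tree, so every continuation can be scored by the backbone in a single forward pass. The `parent` encoding and the function name are assumptions made for exposition.

```python
# Illustrative sketch (assumed data layout, not the authors' code): candidate
# continuations are arranged as a token tree; the mask restricts each node to
# attend only to itself and its ancestors, so all branches share one forward pass.
import torch

def build_tree_attention_mask(parent: list) -> torch.Tensor:
    # parent[i] is the index of node i's parent in the candidate tree,
    # or -1 if node i follows directly from the current sequence prefix.
    n = len(parent)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        mask[i, i] = True          # each token attends to itself...
        j = parent[i]
        while j != -1:             # ...and to every ancestor on its branch
            mask[i, j] = True
            j = parent[j]
    return mask

# Example: nodes 0 and 1 are alternative tokens for the next position;
# nodes 2 and 3 are alternative follow-ups to node 0.
print(build_tree_attention_mask([-1, -1, 0, 0]).int())
```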

Medusa-2 significantly improves inference speed; models trained via self-distillation show smaller gains due to quality-speed trade-offs.

Overview

  • The Medusa framework introduces multiple decoding heads for LLMs to parallelize token predictions and increase inference speed.

  • Medusa-1 involves keeping the backbone LLM static while adding predictive heads, whereas Medusa-2 includes fine-tuning of both the backbone and the added heads.

  • To overcome obstacles such as a lack of training data, the framework introduces a 'self-distillation protocol', along with a 'typical acceptance scheme' that raises the token acceptance rate without degrading generation quality.

  • Experimental results show that Medusa-1 achieves over 2.2 times speedup in inference without quality loss, and Medusa-2 provides even greater speed improvements.

  • Medusa's research code has been released for public collaboration, potentially enhancing LLM inference acceleration in various applications.

Introduction

The computational power and memory bandwidth of contemporary accelerators are underutilized when serving LLMs. The sequential nature of the auto-regressive decoding process causes this bottleneck: each step produces a single token and is bound primarily by memory bandwidth rather than compute, leaving much of the hardware idle. Speculative decoding has been introduced to address these inefficiencies. However, a significant roadblock has been the difficulty of acquiring and maintaining a separate draft model that proposes sequences of tokens for the larger LLM to verify. This is exactly where the Medusa framework comes into play, offering a straightforward solution to the intricate challenge of accelerating LLM inference.

Medusa Framework

The primary innovation introduced with Medusa is the addition of multiple decoding heads to the backbone LLM, which enables the prediction of multiple subsequent tokens in a parallel fashion. These heads are designed to be fine-tuned, ensuring they are closely aligned with the parent LLM in their predictions. Two distinct procedures have been outlined for integrating these predictive heads: Medusa-1 and Medusa-2. Medusa-1 pertains to a setting where the backbone LLM remains frozen during training, thus ensuring no alteration to its core capabilities while accelerating inference speed. Medusa-2 involves a more resource-intensive fine-tuning where the additional heads are trained together with the backbone LLM, potentially achieving even higher efficiency gains.
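As a rough illustration of what such a decoding head can look like, the sketch below adds K small residual feed-forward heads on top of the backbone's final hidden state, each with its own vocabulary projection, so that head k proposes the token k positions beyond the one produced by the original language-modeling head. The sizes, initialization, and class names are assumptions for exposition, not the released implementation.

```python
# Hedged sketch of Medusa-style decoding heads: each head is a small residual
# feed-forward block on the backbone's last hidden state, followed by its own
# vocabulary projection. Shapes and initialization are illustrative assumptions.
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        # Residual feed-forward transform, then project to vocabulary logits.
        return self.lm_head(hidden + self.act(self.proj(hidden)))

# K extra heads predict tokens further ahead from the same hidden state;
# the backbone's original LM head still predicts the immediate next token.
hidden_size, vocab_size, K = 4096, 32000, 4
heads = nn.ModuleList(MedusaHead(hidden_size, vocab_size) for _ in range(K))
h_t = torch.randn(1, hidden_size)                 # last hidden state from the backbone
logits_per_head = [head(h_t) for head in heads]   # K tensors of shape (1, vocab_size)
```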

Addressing Challenges with Extensions

Several obstacles could impede the Medusa framework's widespread adoption, such as situations lacking sufficient training data. To tackle this, the researchers designed a self-distillation protocol, which uses the LLM itself to generate training data for the Medusa heads. They also introduced a 'typical acceptance scheme' as an alternative to the rejection sampling used in speculative decoding, selecting the most plausible predictions from the Medusa heads. This approach maintains generation quality while increasing the rate at which candidate tokens are accepted during decoding.
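A minimal sketch of how such an acceptance rule might look is given below; the threshold hyperparameters `epsilon` and `delta` are illustrative assumptions, not values taken from the paper. The idea is that a candidate token proposed by a Medusa head is kept if the backbone assigns it a probability above a floor that relaxes when the backbone's own next-token distribution is high-entropy.

```python
# Hedged sketch of a typical-acceptance style check (not the paper's exact
# implementation; epsilon and delta are illustrative hyperparameters).
import math
import torch

def typical_accept(probs: torch.Tensor, candidate_id: int,
                   epsilon: float = 0.1, delta: float = 0.3) -> bool:
    # probs: the backbone model's next-token distribution (1-D, sums to 1).
    entropy = float(-(probs * probs.clamp_min(1e-10).log()).sum())
    # Accept if the candidate clears a floor that loosens as entropy rises.
    threshold = min(epsilon, delta * math.exp(-entropy))
    return float(probs[candidate_id]) > threshold

# Along each candidate branch, tokens are accepted until the first rejection;
# the longest accepted prefix across branches is emitted in that decoding step.
```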

Experimental Results

In their comprehensive experiments, the researchers assessed Medusa on various model sizes and training configurations. The findings are significant: Medusa-1 achieves more than a 2.2x speedup in LLM inference with no loss in quality, whereas Medusa-2 pushes this further, attaining speedups of 2.3x to 3.6x. The method also scales across different models and is particularly effective at a batch size of one, the typical setting when hosting an LLM locally for personal use.

Conclusion

Medusa sets a strong precedent for inference acceleration in LLMs without compromising generation quality. Its two training approaches cater to different computational-resource scenarios, and the proposed extensions address common obstacles to deploying accelerated inference. The code for Medusa has been released publicly, inviting collaborative efforts to further refine the framework and integrate it into different serving systems.
