Better & Faster Large Language Models via Multi-token Prediction

(2404.19737)
Published Apr 30, 2024 in cs.CL

Abstract

Large language models such as GPT and Llama are trained with a next-token prediction loss. In this work, we suggest that training language models to predict multiple future tokens at once results in higher sample efficiency. More specifically, at each position in the training corpus, we ask the model to predict the following n tokens using n independent output heads, operating on top of a shared model trunk. Considering multi-token prediction as an auxiliary training task, we measure improved downstream capabilities with no overhead in training time for both code and natural language models. The method is increasingly useful for larger model sizes, and keeps its appeal when training for multiple epochs. Gains are especially pronounced on generative benchmarks like coding, where our models consistently outperform strong baselines by several percentage points. Our 13B parameter models solve 12% more problems on HumanEval and 17% more on MBPP than comparable next-token models. Experiments on small algorithmic tasks demonstrate that multi-token prediction is favorable for the development of induction heads and algorithmic reasoning capabilities. As an additional benefit, models trained with 4-token prediction are up to 3 times faster at inference, even with large batch sizes.

Figure: The multi-token prediction method and its improvement in accuracy on the MBPP code task.

Overview

  • The paper introduces a training method for LLMs known as multi-token prediction, where the model predicts several future tokens at each position through independent output heads on a shared trunk, improving sample efficiency.

  • Performance gains grow with model size, with larger models benefiting most from looking further ahead in the token sequence; trained models also reach up to 3 times faster inference and improved accuracy on several benchmarks.

  • Practical applications include stronger code generation and algorithmic reasoning, and the method fits into existing training pipelines as an auxiliary objective alongside standard next-token prediction.

Enhancing LLMs with Multi-Token Prediction

Introduction to Multi-Token Prediction

In traditional training of LLMs such as GPT and Llama, the model predicts the next token given the previous tokens in a sequence. While simple and effective, this objective has a drawback: it rewards short-term, local predictions and gives the model little direct incentive to capture longer-range dependencies in the text.

The reviewed study proposes a twist on this recipe: training LLMs to predict several tokens ahead in one forward pass, instead of just the next one. This approach, called multi-token prediction, has been shown to improve both sample efficiency and downstream performance, especially as the model scales up.
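
To make this concrete, here is a minimal PyTorch sketch of the setup described above: a shared causal trunk feeds n independent output heads, and the training loss sums the cross-entropy of head k predicting the token k steps ahead. The module names, sizes, and plain linear heads are illustrative assumptions rather than the paper's exact architecture (the paper implements the heads as extra transformer layers sharing the unembedding matrix).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTokenPredictor(nn.Module):
    """Toy model: shared transformer trunk + n independent output heads."""
    def __init__(self, vocab_size=32000, d_model=512, n_layers=4, n_future=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True)
        self.trunk = nn.TransformerEncoder(layer, num_layers=n_layers)  # shared trunk
        # One independent head per future offset 1..n (head k predicts token t+k).
        self.heads = nn.ModuleList([nn.Linear(d_model, vocab_size) for _ in range(n_future)])

    def forward(self, tokens):                                 # tokens: (batch, seq)
        mask = nn.Transformer.generate_square_subsequent_mask(tokens.size(1)).to(tokens.device)
        hidden = self.trunk(self.embed(tokens), mask=mask)     # causal shared representation
        return [head(hidden) for head in self.heads]           # n logit tensors

def multi_token_loss(model, tokens):
    """Sum of per-head cross-entropies: head k is trained on the token k steps ahead."""
    loss = 0.0
    for k, logits in enumerate(model(tokens), start=1):
        pred = logits[:, :-k]          # drop the last k positions (no target k steps ahead)
        target = tokens[:, k:]         # the token k steps ahead of each remaining position
        loss = loss + F.cross_entropy(pred.reshape(-1, pred.size(-1)), target.reshape(-1))
    return loss
```

With a single head this reduces to the standard next-token objective, which is why multi-token prediction can be treated as an auxiliary task layered on top of ordinary training.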

Key Findings and Contributions

The study outlines several compelling points:

  • Multi-token Architecture: The architecture pairs a shared trunk with multiple output heads that predict several future tokens simultaneously, without additional training time or memory overhead.
  • Enhanced Performance at Scale: The improvements are more pronounced as model size increases, with larger models benefiting significantly from being able to look further ahead in the token sequence.
  • Inference Efficiency: Beyond better accuracy, the method speeds up inference. Models trained with 4-token prediction decode up to 3 times faster, because the extra heads can draft several tokens that a single verification pass then accepts or rejects (self-speculative decoding); see the sketch after this list.
  • Robust Empirical Validation: The models were tested across a range of code and natural language benchmarks, consistently outperforming standard next-token prediction baselines. For example, 13B parameter models solved 12% more problems on the HumanEval benchmark and 17% more on the MBPP benchmark than their next-token counterparts.

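The inference speedup listed above comes from using the extra heads for self-speculative decoding: the heads draft several future tokens at once, and a single verification pass accepts the longest prefix that agrees with ordinary greedy decoding. Below is a minimal greedy sketch reusing the MultiTokenPredictor interface from the earlier code (a list of per-head logits); the function name, the acceptance rule, and the absence of a KV cache are simplifications for illustration, not the paper's exact decoding procedure.

```python
import torch

@torch.no_grad()
def self_speculative_generate(model, tokens, max_new_tokens=64):
    """Greedy self-speculative decoding with a multi-head model (batch size 1,
    no KV cache, for clarity). Accepted tokens match plain greedy decoding."""
    produced = 0
    while produced < max_new_tokens:
        # One trunk pass drafts n tokens: head k proposes the token k steps ahead.
        draft = [logits[:, -1].argmax(dim=-1) for logits in model(tokens)]
        # One verification pass over the extended sequence, using only the
        # next-token head; the first draft is what greedy decoding would pick anyway.
        extended = torch.cat([tokens] + [d.unsqueeze(1) for d in draft], dim=1)
        verify = model(extended)[0].argmax(dim=-1)      # greedy next-token choices
        accepted = [draft[0]]
        last = tokens.size(1) - 1                       # index of the last context token
        for k in range(1, len(draft)):
            if torch.equal(verify[:, last + k], draft[k]):
                accepted.append(draft[k])               # draft agrees with greedy decoding
            else:
                break                                   # stop at the first mismatch
        tokens = torch.cat([tokens] + [t.unsqueeze(1) for t in accepted], dim=1)
        produced += len(accepted)
    return tokens
```

Each iteration of this sketch emits between 1 and n tokens per pair of forward passes; with caching and a more careful implementation, the same draft-and-verify idea underlies the reported inference speedups.
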
Practical Implications

  • For LLM Practitioners: The method can be applied directly to improve training efficiency and inference speed, making it attractive for anyone looking to strengthen their language models, especially on code generation and reasoning-heavy tasks.
  • For Research and Development: The simplicity of multi-token prediction, combined with its measurable benefits, opens a new avenue for training LLMs that researchers can explore further, potentially leading to models with stronger reasoning capabilities.

Future Outlook

The introduction of multi-token prediction opens several questions and opportunities for future work. Adjusting the number of predicted tokens to the model size or task complexity intuitively seems beneficial, but finding optimal configurations automatically remains a challenge. Furthermore, integrating the method into broader NLP pipelines and exploring its potential across various domains and languages could extend these gains well beyond code.

Conclusion

The ability to predict multiple future tokens simultaneously is a meaningful step forward in the training of LLMs. It encourages models to capture longer contextual dependencies while improving sample efficiency and inference speed, and it proves its value across a range of benchmarks. Methods like multi-token prediction show that models can become faster and more capable while making better use of the same training data and compute.
