Tandem Transformers for Inference Efficient LLMs

(2402.08644)
Published Feb 13, 2024 in cs.AI and cs.CL

Abstract

The autoregressive nature of conventional LLMs inherently limits inference speed, as tokens are generated sequentially. While speculative and parallel decoding techniques attempt to mitigate this, they face limitations: either relying on less accurate smaller models for generation or failing to fully leverage the base LLM's representations. We introduce a novel architecture, Tandem transformers, to address these issues. This architecture uniquely combines (1) a small autoregressive model and (2) a large model operating in block mode (processing multiple tokens simultaneously). The small model's predictive accuracy is substantially enhanced by granting it attention to the large model's richer representations. On the PaLM2 pretraining dataset, a tandem of PaLM2-Bison and PaLM2-Gecko demonstrates a 3.3% improvement in next-token prediction accuracy over a standalone PaLM2-Gecko, offering a 1.16x speedup compared to a PaLM2-Otter model with comparable downstream performance. We further incorporate the tandem model within the speculative decoding (SPEED) framework where the large model validates tokens from the small model. This ensures that the Tandem of PaLM2-Bison and PaLM2-Gecko achieves substantial speedup (around 1.14x faster than using vanilla PaLM2-Gecko in SPEED) while maintaining identical downstream task accuracy.


Overview

  • Tandem Transformers introduce an architecture that separates natural language understanding (NLU) from natural language generation (NLG) in LLMs, improving inference efficiency without losing accuracy.

  • The Tandem architecture combines a smaller autoregressive model with a larger model that processes blocks of tokens simultaneously; the larger model's richer representations enhance the smaller model's predictive accuracy.

  • Empirical evaluations using PaLM2 models show a 3.3% improvement in next-token prediction accuracy and a 1.16x inference speedup over a PaLM2-Otter model with comparable downstream performance.

  • The research suggests a shift in LLM design towards more efficient inference, theoretically decoupling NLU and NLG tasks, and opens up future research paths for further optimization.

Enhancing LLMs with Tandem Transformers for Efficient Inference

Introduction to Tandem Transformers

LLMs have seen remarkable advances, yet their deployment in real-world applications is often constrained by significant computational cost. The conventional autoregressive nature of LLMs, which requires tokens to be generated one at a time, substantially limits inference speed: each decoding step reduces largely to memory-bound matrix-vector operations, leaving machine learning accelerators built for parallel matrix-matrix computation underutilized. Tackling this challenge, the paper introduces Tandem Transformers, an architecture that disaggregates the natural language understanding (NLU) and natural language generation (NLG) capacities of LLMs to achieve efficient inference without compromising accuracy.
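To make the bottleneck concrete, here is a minimal NumPy sketch (illustrative only, not from the paper) contrasting token-by-token decoding, which performs one matrix-vector product per token, with block-mode processing, which covers the same tokens in a single matrix-matrix product that parallel hardware can exploit far better. The matrix sizes and timing loop are arbitrary choices for illustration.

```python
import time
import numpy as np

d_model, block_len = 4096, 16  # illustrative sizes, not the paper's
W = np.random.randn(d_model, d_model).astype(np.float32)     # one weight matrix of a transformer layer
xs = np.random.randn(block_len, d_model).astype(np.float32)  # hidden states for a block of 16 tokens

# Autoregressive decoding: one matrix-vector product per token (memory-bandwidth bound).
t0 = time.perf_counter()
for x in xs:
    _ = W @ x
t_sequential = time.perf_counter() - t0

# Block-mode processing: a single matrix-matrix product covering all 16 tokens.
t0 = time.perf_counter()
_ = xs @ W.T
t_block = time.perf_counter() - t0

print(f"sequential: {t_sequential * 1e3:.2f} ms   block: {t_block * 1e3:.2f} ms")
```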

Architectural Innovation

The Tandem Transformers architecture combines a smaller autoregressive model with a larger model operating in block mode. The large model processes blocks of multiple tokens simultaneously, and the small model's predictive accuracy is substantially improved by letting it attend to the large model's richer representations. A key observation is that current decoder-only LLM architectures tightly couple the capacity used for understanding the input prompt (NLU) with the capacity used for generating the response (NLG). Tandem Transformers allocate more capacity to NLU than to NLG and investigate whether high-quality response generation can be maintained under this split.
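The generation loop this implies can be sketched in toy Python; the class names, the block length GAMMA, and the stand-in representation and mixing logic below are hypothetical placeholders rather than the paper's implementation. The large model periodically processes all tokens so far in one block-mode pass, and the small model then generates the next block autoregressively while consulting those representations.

```python
import random
from typing import List

GAMMA = 3  # block length: tokens the small model generates per large-model pass (hypothetical value)

class LargeBlockModel:
    """Toy stand-in for the large model run in block mode (e.g., PaLM2-Bison)."""
    def encode_block(self, tokens: List[int]) -> List[float]:
        # One parallel pass producing a representation per token seen so far.
        return [float(t) for t in tokens]

class SmallARModel:
    """Toy stand-in for the small autoregressive model (e.g., PaLM2-Gecko)."""
    vocab_size = 100
    def next_token(self, tokens: List[int], large_reprs: List[float]) -> int:
        # In the real architecture the small model attends to (projected) large-model
        # representations alongside its own context; here we just fold them into a seed.
        random.seed(len(tokens) * 31 + int(sum(large_reprs)))
        return random.randrange(self.vocab_size)

def tandem_generate(prompt: List[int], num_blocks: int) -> List[int]:
    large, small = LargeBlockModel(), SmallARModel()
    tokens = list(prompt)
    reprs = large.encode_block(tokens)        # large model processes the prompt in block mode
    for _ in range(num_blocks):
        block: List[int] = []
        for _ in range(GAMMA):                # small model drafts the next GAMMA tokens one by one
            block.append(small.next_token(tokens + block, reprs))
        tokens += block
        reprs = large.encode_block(tokens)    # large model refreshes representations after each block
    return tokens

print(tandem_generate(prompt=[1, 2, 3], num_blocks=2))
```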

Empirical Evaluation and Results

The paper presents an empirical evaluation of the Tandem architecture using PaLM2-Bison and PaLM2-Gecko models. The results show a 3.3% improvement in next-token prediction accuracy for the Tandem model over a standalone PaLM2-Gecko, alongside a 1.16x inference speedup over a PaLM2-Otter model with comparable downstream performance. The efficiency gain carries over to the Speculative Decoding (SPEED) framework, where the large model verifies the small model's drafted tokens: the Tandem model is around 1.14x faster than using a vanilla PaLM2-Gecko drafter in SPEED while maintaining identical downstream task accuracy.
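The verification loop can be sketched as follows; this is a simplified greedy-acceptance variant of speculative decoding with toy stand-in models (small_next, large_next, GAMMA, and VOCAB are illustrative inventions), not the paper's SPEED implementation. The small tandem model drafts a block of tokens, the large model checks every drafted position in what would be a single parallel pass, and only the prefix the large model agrees with is kept, which is why downstream accuracy matches decoding with the large model alone.

```python
import random
from typing import List

GAMMA = 4  # tokens drafted by the small (tandem) model per verification step
VOCAB = 8  # toy vocabulary size

def small_next(ctx: List[int]) -> int:
    """Toy stand-in for the small tandem model's greedy next-token prediction."""
    random.seed(hash(tuple(ctx)) % 1000)
    return random.randrange(VOCAB)

def large_next(ctx: List[int]) -> int:
    """Toy stand-in for the large model's greedy next-token prediction."""
    random.seed(hash(tuple(ctx)) % 1000 + 1)
    return random.randrange(VOCAB)

def speculative_step(ctx: List[int]) -> List[int]:
    # 1. Draft: the small model proposes GAMMA tokens autoregressively.
    drafted: List[int] = []
    for _ in range(GAMMA):
        drafted.append(small_next(ctx + drafted))
    # 2. Verify: the large model scores every drafted position (plus one extra)
    #    in what would be a single parallel block-mode pass on real hardware.
    verified = [large_next(ctx + drafted[:i]) for i in range(GAMMA + 1)]
    # 3. Accept the longest prefix where draft and target agree, then append the
    #    large model's own token, so the result matches greedy decoding with the
    #    large model alone.
    out = list(ctx)
    for i in range(GAMMA):
        if drafted[i] == verified[i]:
            out.append(drafted[i])
        else:
            out.append(verified[i])
            return out
    out.append(verified[GAMMA])
    return out

ctx = [1, 2, 3]
for _ in range(3):
    ctx = speculative_step(ctx)
print(ctx)
```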

Theoretical and Practical Implications

The research introduces a significant shift towards more inference-efficient LLMs by proposing a split in model capacity between NLU and NLG tasks. Theoretically, it challenges the traditional design of LLM architectures by providing evidence that the tasks of understanding prompts and generating responses can be effectively decoupled without loss of performance. Practically, this implies that LLMs can be optimized for faster responses in real-time applications, potentially expanding their usability in time-sensitive or resource-constrained environments.

Future Directions

While the Tandem Transformers architecture represents a significant advancement, the paper posits several avenues for future research. These include exploring other variants of tandem architectures, investigating alternatives to Low-rank Adaptation (LoRA) for fine-tuning, and further optimizing the tandem architecture with techniques like adaptive block length for larger batch sizes or multiple samples. Moreover, the potential for utilizing even smaller models within the SPEED framework for further efficiency gains opens up intriguing possibilities for LLM optimization.

Conclusion

The development of Tandem Transformers marks a noteworthy step towards addressing the efficiency challenges faced by current LLMs in deployment scenarios. By enabling LLMs to process input and generate responses more efficiently without sacrificing accuracy, this architecture paves the way for broader adoption and application of LLMs across various domains. As the field continues to evolve, the principles outlined in this work will undoubtedly contribute to the ongoing dialogue on optimizing LLMs for practical, real-world use cases.
