
SparQ Attention: Bandwidth-Efficient LLM Inference

(2312.04985)
Published Dec 8, 2023 in cs.LG

Abstract

The computational difficulties of LLM inference remain a significant obstacle to their widespread deployment. The need for many applications to support long input sequences and process them in large batches typically causes token-generation to be bottlenecked by data-transfer. For this reason, we introduce SparQ Attention, a technique for increasing the inference throughput of LLMs by utilising memory bandwidth more efficiently within the attention layers, through selective fetching of the cached history. Our proposed technique can be applied directly to off-the-shelf LLMs during inference, without requiring any modification to the pre-training setup or additional fine-tuning. We show that SparQ Attention brings up to 8x savings in attention data-transfers without substantial drops in accuracy, by evaluating Llama 2, Mistral and Pythia models on a wide range of downstream tasks.

Overview

  • Transformer-based LLMs gain broad capability from pre-training on large text corpora, but inference is costly, with token generation often bottlenecked by data transfer rather than compute.

  • SparQ Attention is introduced to enhance inference efficiency by selectively fetching relevant portions of cached history within the attention mechanism.

  • The SparQ Attention algorithm substantially reduces memory bandwidth demand through a three-step attention procedure, with little to no loss in model accuracy.

  • Performance tests on tasks such as question answering and language modeling, using models with billions of parameters, show robust results, with attention data-transfer reductions between 2× and 8×.

  • SparQ Attention fits within ongoing research on making LLMs more efficient, but is distinguished by applying directly at inference time to off-the-shelf models; future work may extend its capabilities further.

Introduction

Transformer models have become increasingly effective at solving complex language processing tasks by pre-training on extensive text corpora. LLMs benefit from this approach, offering versatile capabilities across various text-based applications. However, an obstacle often encountered with LLMs is their high computational demand during inference. This limitation is particularly pronounced when processing many samples with extended contexts, because the cached attention history (the KV cache) must be read from memory for every generated token, creating significant memory-bandwidth demands. Addressing this challenge, a new technique known as SparQ Attention has been introduced, aiming to enhance inference efficiency by selectively fetching relevant portions of the cached history within the attention mechanism.
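To see why this bottleneck arises, the back-of-the-envelope sketch below (in Python, with assumed model dimensions roughly in the range of a Llama 2 7B-class model; none of these numbers are taken from the paper) estimates how much KV-cache data must be read per generated token.

```python
# Illustrative estimate of why token generation becomes data-transfer bound:
# each generated token must read the entire key-value (KV) cache from memory.
# The dimensions below are assumptions, not figures from the paper.

n_layers = 32        # assumed number of transformer layers
n_kv_heads = 32      # assumed number of key/value heads
d_head = 128         # assumed head dimension
bytes_per_elem = 2   # fp16/bf16 storage

def kv_cache_bytes(seq_len: int, batch_size: int) -> int:
    """Bytes read from the KV cache to generate one token for the whole batch."""
    per_token = 2 * n_layers * n_kv_heads * d_head * bytes_per_elem  # K and V
    return per_token * seq_len * batch_size

# Example: batch of 16 sequences, 4096 tokens of context each.
print(f"{kv_cache_bytes(4096, 16) / 1e9:.1f} GB transferred per generation step")
# -> roughly 34 GB per step under these assumptions, which at large batch sizes
#    and long sequences dwarfs the cost of reading the model weights themselves.
```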

SparQ Attention Algorithm

SparQ Attention improves the efficiency of LLMs during inference by reducing the attention mechanism's memory bandwidth usage. The technique works in three sequential steps. First, it identifies the largest-magnitude components of the incoming query vector and uses only those components to compute approximate attention scores against the cached keys. Second, it fetches the full key and value vectors only for the tokens with the highest approximate scores and computes exact attention over that subset. Finally, it interpolates the resulting output with a running mean of the value vectors, weighted by the fraction of approximate attention mass assigned to the selected tokens. Notably, this approach can reduce attention data transfers by up to eight times with little loss in accuracy, and it can be applied directly to existing LLMs without altering pre-training setups or requiring additional fine-tuning.
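The NumPy sketch below illustrates the three steps for a single attention head at one generation step. It is a minimal illustration rather than the authors' implementation: the parameter names `r` and `k`, the exact form of the softmax temperature correction, and the incremental maintenance of the value mean are assumptions based on the description above.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def sparq_attention(q, K, V, v_mean, r=32, k=128):
    """Sketch of single-head SparQ attention for one generation step.

    q:      (d,)   current query vector
    K, V:   (S, d) cached keys and values
    v_mean: (d,)   running mean of the cached value vectors
    r:      number of query components used for approximate scores
    k:      number of tokens whose full K/V rows are fetched
    """
    S, d = K.shape

    # Step 1: pick the r largest-magnitude query components and use only those
    # key columns to approximate the attention scores (reads S*r elements).
    i1 = np.argsort(np.abs(q))[-r:]
    # Temperature correction (assumed form) so the reduced dot product keeps a
    # comparable scale, based on the fraction of |q| mass the r components cover.
    scale = np.sqrt(d * np.abs(q[i1]).sum() / np.abs(q).sum())
    s_hat = softmax(q[i1] @ K[:, i1].T / scale)

    # Step 2: fetch full key/value rows only for the k highest approximate scores
    # (reads 2*k*d elements) and run exact attention over that subset.
    i2 = np.argsort(s_hat)[-k:]
    s = softmax(q @ K[i2].T / np.sqrt(d))
    y_topk = s @ V[i2]

    # Step 3: interpolate with the running mean of the values, weighted by the
    # approximate attention mass assigned to the selected tokens.
    alpha = s_hat[i2].sum()
    return alpha * y_topk + (1.0 - alpha) * v_mean
```

In practice the key cache must be laid out so it can be read efficiently along both the head dimension (step 1) and the sequence dimension (step 2); such layout and batching considerations are omitted from this sketch.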

Experiment and Results

The practical efficacy of SparQ Attention was tested across a variety of downstream tasks, including question answering, summarization, language modeling, and textual repetition. These tasks were designed to assess model performance under reduced data transfer and to compare SparQ Attention against other sparse attention methods. Strong performance was demonstrated with models such as Llama 2, Mistral and Pythia, spanning billions of parameters, on tasks requiring long-sequence context processing. The technique proved robust, sustaining attention data-transfer compression ratios from 2× to 8×, often with negligible degradation in task performance.
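As a rough, illustrative account of where such compression ratios come from (an estimate derived from the data-access pattern described above, not the paper's exact accounting): dense attention reads the full key and value caches each step, about 2·S·d elements per head for sequence length S and head dimension d, whereas the sparse scheme reads r components of every key plus full key/value rows for the k selected tokens.

```python
def approx_compression_ratio(S, d, r, k):
    """Rough per-head estimate of attention data-transfer savings (illustrative)."""
    dense = 2 * S * d           # read all keys and all values
    sparse = S * r + 2 * k * d  # partial keys (step 1) + full K/V rows for top-k (step 2)
    return dense / sparse

# Example with assumed settings: 4096-token context, head dim 128, r=32, k=128.
print(f"{approx_compression_ratio(4096, 128, 32, 128):.1f}x")  # ~6.4x
```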

Discussion and Related Work

This paper fits within a broader body of research on improving the efficiency of attention mechanisms, including work on sparse attention and on reducing the memory footprint of LLMs. Previous approaches often require modifications during pre-training and can trade off task performance. In contrast, SparQ Attention stands out because it can be applied at inference time without adjusting pre-trained weights. Despite the substantial reductions in memory bandwidth it delivers, SparQ Attention has limitations: it reduces data transfer rather than the memory footprint of the cached keys and values, and its behaviour when combined with other transformer model variants is not yet fully explored. Future research may extend its applicability or address these limitations, further strengthening its role in efficient LLM inference.
