SpecTr: Fast Speculative Decoding via Optimal Transport
(2310.15141v2)
Published 23 Oct 2023 in cs.LG, cs.CL, cs.DS, cs.IT, and math.IT
Abstract: Autoregressive sampling from LLMs has led to state-of-the-art results in several natural language tasks. However, autoregressive sampling generates tokens one at a time making it slow, and even prohibitive in certain tasks. One way to speed up sampling is $\textit{speculative decoding}$: use a small model to sample a $\textit{draft}$ (block or sequence of tokens), and then score all tokens in the draft by the LLM in parallel. A subset of the tokens in the draft are accepted (and the rest rejected) based on a statistical method to guarantee that the final output follows the distribution of the large model. In this work, we provide a principled understanding of speculative decoding through the lens of optimal transport (OT) with $\textit{membership cost}$. This framework can be viewed as an extension of the well-known $\textit{maximal-coupling}$ problem. This new formulation enables us to generalize the speculative decoding method to allow for a set of $k$ candidates at the token-level, which leads to an improved optimal membership cost. We show that the optimal draft selection algorithm (transport plan) can be computed via linear programming, whose best-known runtime is exponential in $k$. We then propose a valid draft selection algorithm whose acceptance probability is $(1-1/e)$-optimal multiplicatively. Moreover, it can be computed in time almost linear with size of domain of a single token. Using this $\textit{new draft selection}$ algorithm, we develop a new autoregressive sampling algorithm called $\textit{SpecTr}$, which provides speedup in decoding while ensuring that there is no quality degradation in the decoded output. We experimentally demonstrate that for state-of-the-art LLMs, the proposed approach achieves a wall clock speedup of 2.13X, a further 1.37X speedup over speculative decoding on standard benchmarks.
The paper introduces a novel framework for speculative decoding using optimal transport with a membership cost to maximize accepted tokens from multi-candidate drafts.
It derives an LP formulation and proposes a greedy algorithm that achieves a (1-1/e) approximation with near-linear time complexity.
Empirical results demonstrate up to 2.13X speedup over standard decoding while preserving target distribution consistency.
Speculative decoding accelerates autoregressive sampling from LLMs by using a smaller, faster draft model to propose sequences (drafts) of tokens, which are then verified in parallel by the larger target LLM. Only a prefix of the draft sequence statistically consistent with the target model's distribution is accepted. "SpecTr: Fast Speculative Decoding via Optimal Transport" (Sun et al., 2023) introduces a novel framework for speculative decoding based on Optimal Transport (OT) with a membership cost, generalizing existing methods and leading to improved performance.
Theoretical Framework: Optimal Transport with Membership Cost
Standard speculative decoding aims to sample from a target distribution $p(x_{t+1} \mid x_{1:t})$ using proposals from a draft distribution $q(x_{t+1} \mid x_{1:t})$. A draft sequence $x'_{t+1}, \ldots, x'_{t+\gamma}$ is sampled from $q$. Then, for each position $i = 1, \ldots, \gamma$, the token $x'_{t+i}$ is accepted if a random variable $U \sim \mathrm{Uniform}(0,1)$ satisfies $U \le p(x'_{t+i} \mid x_{1:t+i-1}) / q(x'_{t+i} \mid x_{1:t+i-1})$. If rejected, a correction token is sampled from a modified distribution derived from $p$ and $q$, and the process restarts.
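For concreteness, the following is a minimal NumPy sketch of this single-draft accept/reject step; the toy distributions p and q stand in for the models' next-token distributions and are not from the paper.

import numpy as np

rng = np.random.default_rng(0)

# Toy next-token distributions over a 5-token vocabulary (placeholders for
# the target model p and the draft model q at the current position).
p = np.array([0.4, 0.3, 0.1, 0.1, 0.1])
q = np.array([0.2, 0.2, 0.2, 0.2, 0.2])

y = rng.choice(len(q), p=q)                  # draft token sampled from q
if rng.uniform() <= min(1.0, p[y] / q[y]):   # accept with prob min(1, p/q)
    token = y
else:
    residual = np.maximum(p - q, 0.0)        # rejected: sample from norm((p - q)_+)
    token = rng.choice(len(p), p=residual / residual.sum())

Either way, the emitted token is distributed exactly according to p.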
SpecTr reformulates this process through the lens of optimal transport. The goal is to find a coupling (a joint probability distribution) $\pi(x, y)$ between the draft distribution $q(y)$ and the target distribution $p(x)$ (conditioned on the history $x_{1:t}$, omitted for brevity) that maximizes the expected number of accepted tokens. This is framed as an OT problem with a specific "membership cost".
The key generalization introduced by SpecTr is allowing the draft model to propose a set of $k$ candidate tokens $\{y_1, \ldots, y_k\}$ at each position, instead of just one. Let $Q$ be the distribution over sets of $k$ candidates generated by the draft mechanism. The objective is to find a transport plan $\pi(x \mid Y)$, where $Y = \{y_1, \ldots, y_k\} \sim Q$, that defines the probability of selecting the target token $x$ given the proposed set $Y$. This plan must satisfy the marginal constraint $\sum_Y Q(Y)\,\pi(x \mid Y) = p(x)$ for all $x$, ensuring the final output follows the target distribution $p$.
The goal is to maximize the probability that the selected token $x$ is one of the proposed candidates, i.e., to maximize $\mathbb{E}_{Y \sim Q}\big[\sum_{x \in Y} \pi(x \mid Y)\big]$. Maximizing this per-position acceptance probability in turn maximizes the expected number of tokens accepted from the draft. This formulation extends the maximal-coupling problem by considering a set of candidates ($k \ge 1$) rather than just a single proposal.
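Concretely, the membership cost can be read as the indicator that the target sample falls outside the draft set, so minimizing the expected transport cost under this cost function is equivalent to maximizing the acceptance probability:

$$c(x, Y) = \mathbf{1}\{x \notin Y\}, \qquad \min_{\pi} \; \mathbb{E}_{\pi}\big[c(x, Y)\big] \;=\; 1 - \max_{\pi} \; \mathbb{E}_{Y \sim Q}\Big[\sum_{x \in Y} \pi(x \mid Y)\Big].$$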
Optimal Draft Selection and Computational Complexity
The paper demonstrates that the optimal transport plan $\pi^*(x \mid Y)$ maximizing the acceptance probability can be formulated and solved as a linear program (LP). The variables of the LP are $\pi(x \mid Y)$ for all possible target tokens $x$ and all possible candidate sets $Y$. The objective is to maximize $\sum_Y Q(Y) \sum_{x \in Y} \pi(x \mid Y)$ subject to the marginal constraint $\sum_Y Q(Y)\,\pi(x \mid Y) = p(x)$ for all $x$, and the probability constraints $\pi(x \mid Y) \ge 0$ and $\sum_x \pi(x \mid Y) = 1$ for all $Y$.
While theoretically solvable via LP, the number of possible candidate sets $Y$ grows combinatorially with $k$ and the vocabulary size $|V|$, making the LP formulation computationally intractable for practical values of $k$ and $|V|$. The best-known runtime for solving such LPs is exponential in $k$.
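To make the blow-up concrete, here is a brute-force instance of this LP for a toy vocabulary and $k = 2$ i.i.d. draft candidates, solved with scipy; the distributions are invented for illustration, and the explicit enumeration over all $|V|^k$ candidate tuples is exactly what becomes intractable at scale.

import numpy as np
from itertools import product
from scipy.optimize import linprog

V = 3                             # tiny vocabulary {0, 1, 2}
k = 2                             # two draft candidates per position
p = np.array([0.5, 0.3, 0.2])     # toy target distribution
q = np.array([0.2, 0.3, 0.5])     # toy draft distribution

sets = list(product(range(V), repeat=k))               # all ordered tuples Y
Q = [float(np.prod([q[y] for y in Y])) for Y in sets]  # i.i.d. draws from q

# Variables pi(x | Y), flattened as index s * V + x for set s and token x.
n = len(sets) * V

# Objective: maximize sum_Y Q(Y) sum_{x in Y} pi(x | Y); linprog minimizes.
c = np.zeros(n)
for s, Y in enumerate(sets):
    for x in range(V):
        if x in Y:
            c[s * V + x] = -Q[s]

# Equality constraints: marginal sum_Y Q(Y) pi(x | Y) = p(x) for every x,
# plus row-stochasticity sum_x pi(x | Y) = 1 for every Y.
A_eq = np.zeros((V + len(sets), n))
b_eq = np.zeros(V + len(sets))
for x in range(V):
    for s in range(len(sets)):
        A_eq[x, s * V + x] = Q[s]
    b_eq[x] = p[x]
for s in range(len(sets)):
    A_eq[V + s, s * V:(s + 1) * V] = 1.0
    b_eq[V + s] = 1.0

res = linprog(c, A_eq=A_eq, b_eq=b_eq, bounds=(0.0, 1.0))
print("optimal acceptance probability:", -res.fun)

With $|V|$ in the tens of thousands and $k > 1$, the $|V|^{k+1}$ variables make this direct approach hopeless, motivating the approximate algorithm below.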
The SpecTr Algorithm and Approximation Guarantee
To overcome the computational bottleneck of solving the exact LP, SpecTr introduces an efficient, greedy algorithm for constructing a valid transport plan $\pi(x \mid Y)$ that achieves a strong approximation guarantee. This algorithm provides a $(1 - 1/e)$-multiplicative approximation to the optimal acceptance probability achievable by any valid transport plan. Crucially, this approximate algorithm runs in time nearly linear in the vocabulary size $|V|$, specifically $O(|V| \log |V|)$, or potentially $O(|V|)$ with appropriate data structures, making it highly practical.
The core idea of the greedy algorithm resembles online bipartite matching or ad allocation algorithms. It iteratively assigns probability mass from the target distribution p(x) to candidate sets Y containing x, prioritizing assignments that yield the highest "gain" in terms of acceptance probability, while carefully managing the remaining probability mass of p(x) and Q(Y) to ensure the marginal constraints are met.
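The paper's exact construction should be taken from (Sun et al., 2023); as a rough illustration of the flavor of such schemes, the sketch below implements a simpler valid multi-candidate selection rule, applying the single-draft accept/reject test sequentially to $k$ i.i.d. candidates and updating the residual target after each rejection. This preserves the target marginal $p$ exactly but is not claimed to attain the $(1 - 1/e)$ guarantee; all names are illustrative.

import numpy as np

def sequential_select(p, q, candidates, rng):
    # Residual of the target distribution still to be covered; starts at p.
    r = p.astype(float).copy()
    for y in candidates:                 # candidates are i.i.d. draws from q
        r = r / r.sum()                  # renormalize the residual target
        if rng.uniform() <= min(1.0, r[y] / q[y]):
            return y                     # candidate accepted
        r = np.maximum(r - q, 0.0)       # peel off coupled mass and recurse
    # Every candidate rejected: emit a correction token from the residual,
    # so the overall output is still distributed exactly as p.
    return rng.choice(len(p), p=r / r.sum())

rng = np.random.default_rng(0)
p = np.array([0.5, 0.3, 0.2])
q = np.array([0.25, 0.25, 0.5])
Y = rng.choice(3, size=4, p=q)           # k = 4 i.i.d. draft candidates
token = sequential_select(p, q, Y, rng)

Each additional candidate gets a fresh chance to cover whatever target mass the previous ones missed, which is the intuition behind the gain from $k > 1$.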
The overall SpecTr sampling procedure at each step $t$ involves:
1. Candidate Generation: Use the draft model $q$ to generate a set of $k$ candidate tokens $Y = \{y_1, \ldots, y_k\}$ for position $t+1$. This might involve beam search or multiple sampling passes from $q$.
2. Target Probability Calculation: Query the target model $p$ in parallel to obtain the probabilities $p(y_i \mid x_{1:t})$ for all candidates $y_i \in Y$. Note that obtaining the full distribution $p(x \mid x_{1:t})$ might be needed for the selection algorithm, depending on its specifics.
3. Draft Selection: Apply the efficient $(1 - 1/e)$-approximate algorithm using $p$, $q$, and the candidate set $Y$ to determine the acceptance/rejection probabilities for the candidates. This yields a sampling distribution over $Y \cup \{\text{reject}\}$.
4. Sampling: Sample an outcome according to the distribution computed in step 3.
- If a token $y_i \in Y$ is sampled, it is accepted. Increment $t$ and repeat from step 1 for the next position.
- If 'reject' is sampled, discard all candidates. Sample a correction token from the residual distribution required to maintain consistency with $p$. Restart the process from step 1 for position $t+1$.
This process allows multiple tokens (potentially the entire sequence proposed by the draft model, depending on acceptance) to be decoded per parallel invocation of the target model, amortizing its cost across positions and exploiting hardware parallelism for the verification step.
Implementation Considerations
Implementing SpecTr involves several key components and choices:
Draft Model Selection: The choice of the draft model q is critical. It needs to be significantly faster than the target model p but also provide distributions reasonably close to p to ensure a high acceptance rate. Smaller versions of the target model or distilled models are common choices. The method used to generate the k candidates (e.g., beam search with width k, diverse beam search, multiple independent samples) also impacts performance.
Target Model Parallelism: The primary speedup comes from evaluating the target model p on multiple candidate tokens yi in parallel. Efficient batching and hardware utilization (GPUs/TPUs) are essential.
Parameter k: Selecting the number of candidates k involves a trade-off. Higher k increases the theoretical upper bound on the acceptance rate via the (1−1/e)-optimal algorithm, potentially leading to greater speedup. However, it also increases the cost of generating candidates (step 1) and potentially the complexity of the selection algorithm (step 3), though the proposed algorithm is efficient. The optimal k is likely task- and model-dependent and requires empirical tuning. Values like k=4 or k=8 might be practical starting points.
Selection Algorithm Implementation: The efficient $(1 - 1/e)$-approximate algorithm needs careful implementation. It likely involves sorting or priority queues keyed on probability ratios such as $p(y_i)/q(y_i)$, or similar quantities derived from the OT formulation; the exact details should be taken from the algorithm description in the paper (Sun et al., 2023). Ensuring numerical stability when dealing with small probabilities is important (see the sketch after this list).
Distribution Handling: The algorithm requires access to the probabilities $p(y_i \mid \ldots)$ and potentially $q(y_i \mid \ldots)$ or the distribution $Q(Y)$. Efficiently querying these distributions and handling the residual distribution for rejection sampling are necessary (also illustrated after this list).
System Integration: Integrating SpecTr into an LLM inference framework requires modifying the generation loop. Instead of sampling one token, it generates k candidates, runs the parallel verification, executes the SpecTr selection logic, and handles acceptance/rejection/correction. This involves managing the state across multiple potential future tokens.
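As a hedged illustration of the numerical-stability and residual-handling points above, acceptance ratios can be computed in log space and the residual (correction) distribution formed explicitly; the helper names are invented for this sketch.

import numpy as np

def log_accept_prob(log_p, log_q):
    # log of min(1, p/q), computed without exponentiating tiny numbers.
    return np.minimum(0.0, log_p - log_q)

def residual_distribution(p, q):
    # Normalized (p - q)_+, used to sample a correction token on rejection
    # in the single-draft case; the multi-candidate residual depends on the
    # selection algorithm's transport plan.
    r = np.maximum(p - q, 0.0)
    s = r.sum()
    return p if s == 0.0 else r / s   # s == 0 only when p == q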
Below is high-level conceptual pseudocode for the SpecTr sampling loop, written as Python with the helper functions left as placeholders:

def spectr_sample(target_model_p, draft_model_q, k, prompt):
    tokens = list(prompt)
    finished = False
    while not finished:
        # 1. Generate k candidates for the next position using the draft
        #    model q (e.g., k independent samples or beam search of width k).
        Y = generate_candidates(draft_model_q, tokens, k)

        # 2. Evaluate target probabilities for the candidates in parallel:
        #    {y_i: p(y_i | tokens) for y_i in Y}. Depending on the selection
        #    algorithm, draft probabilities q(y_i | tokens) or the full
        #    distribution p(. | tokens) may also be needed.
        target_probs_p = target_model_p.get_probs(tokens, Y)
        draft_probs_q = draft_model_q.get_probs(tokens, Y)

        # 3. Run the SpecTr selection algorithm, which computes a sampling
        #    distribution over Y ∪ {REJECT}.
        sampling_dist = spectr_selection_algorithm(target_probs_p, draft_probs_q, Y)

        # 4. Sample an outcome from that distribution.
        outcome = sample_from_distribution(sampling_dist)

        if outcome is REJECT:
            # All candidates discarded: sample a correction token from the
            # residual distribution so the overall output still follows p.
            token = sample_correction_token(target_model_p, draft_model_q, tokens, Y)
        else:
            token = outcome  # some accepted y_i in Y

        tokens.append(token)

        # Check termination (e.g., maximum length reached or EOS emitted).
        finished = should_terminate(tokens)

    return tokens
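The helpers above are deliberately abstract. As one hedged possibility, candidate generation via i.i.d. draws from the draft model and the single-draft form of the correction sampler might look as follows, reusing residual_distribution from the earlier sketch; the next_token_probs model interface is an assumption for illustration, not an API from the paper.

import numpy as np

rng = np.random.default_rng()

def generate_candidates(draft_model_q, tokens, k):
    # k i.i.d. candidate tokens from the draft model's next-token
    # distribution; beam search or diverse sampling are alternatives.
    q = draft_model_q.next_token_probs(tokens)   # assumed interface
    return rng.choice(len(q), size=k, p=q)

def sample_correction_token(target_model_p, draft_model_q, tokens, Y):
    # Single-draft residual shown for simplicity; with k > 1 candidates the
    # residual depends on the transport plan used by the selection step.
    p = target_model_p.next_token_probs(tokens)  # assumed interface
    q = draft_model_q.next_token_probs(tokens)
    return rng.choice(len(p), p=residual_distribution(p, q))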
Experimental Results
The paper reports significant empirical speedups. On standard benchmarks using state-of-the-art LLMs, SpecTr achieved a wall clock speedup of 2.13X compared to standard autoregressive decoding. Furthermore, it demonstrated a 1.37X speedup over the baseline speculative decoding method (which corresponds to the case k=1). These results highlight the practical benefit of the OT formulation and the multi-candidate (k>1) approach enabled by SpecTr. The authors emphasize that this speedup is achieved while mathematically guaranteeing that the output distribution remains identical to that of the target model p, thus incurring no degradation in generation quality.
Conclusion
SpecTr provides a principled extension to speculative decoding by framing it as an optimal transport problem with membership cost. This allows generalizing the mechanism to handle multiple (k) draft candidates per position. While the exact optimal solution is computationally expensive, SpecTr introduces an efficient approximation algorithm with a (1−1/e)-optimality guarantee and near-linear time complexity. This leads to substantial wall-clock speedups in practice (reported as 2.13X overall, 1.37X over standard speculative decoding) without compromising the output quality, making it a promising technique for accelerating inference in LLMs. Implementing SpecTr requires careful integration with existing inference systems, focusing on efficient candidate generation, parallel target model evaluation, and the implementation of the novel selection algorithm.