
SpecTr: Fast Speculative Decoding via Optimal Transport (2310.15141v2)

Published 23 Oct 2023 in cs.LG, cs.CL, cs.DS, cs.IT, and math.IT

Abstract: Autoregressive sampling from LLMs has led to state-of-the-art results in several natural language tasks. However, autoregressive sampling generates tokens one at a time making it slow, and even prohibitive in certain tasks. One way to speed up sampling is $\textit{speculative decoding}$: use a small model to sample a $\textit{draft}$ (block or sequence of tokens), and then score all tokens in the draft by the LLM in parallel. A subset of the tokens in the draft are accepted (and the rest rejected) based on a statistical method to guarantee that the final output follows the distribution of the large model. In this work, we provide a principled understanding of speculative decoding through the lens of optimal transport (OT) with $\textit{membership cost}$. This framework can be viewed as an extension of the well-known $\textit{maximal-coupling}$ problem. This new formulation enables us to generalize the speculative decoding method to allow for a set of $k$ candidates at the token-level, which leads to an improved optimal membership cost. We show that the optimal draft selection algorithm (transport plan) can be computed via linear programming, whose best-known runtime is exponential in $k$. We then propose a valid draft selection algorithm whose acceptance probability is $(1-1/e)$-optimal multiplicatively. Moreover, it can be computed in time almost linear with size of domain of a single token. Using this new draft selection algorithm, we develop a new autoregressive sampling algorithm called $\textit{SpecTr}$, which provides speedup in decoding while ensuring that there is no quality degradation in the decoded output. We experimentally demonstrate that for state-of-the-art LLMs, the proposed approach achieves a wall clock speedup of 2.13X, a further 1.37X speedup over speculative decoding on standard benchmarks.

Authors (6)
  1. Ziteng Sun (29 papers)
  2. Ananda Theertha Suresh (73 papers)
  3. Jae Hun Ro (7 papers)
  4. Ahmad Beirami (86 papers)
  5. Himanshu Jain (19 papers)
  6. Felix Yu (62 papers)
Citations (50)

Summary

  • The paper introduces a novel framework for speculative decoding using optimal transport with a membership cost to maximize accepted tokens from multi-candidate drafts.
  • It derives an LP formulation and proposes a greedy algorithm that achieves a (1-1/e) approximation with near-linear time complexity.
  • Empirical results demonstrate up to 2.13X speedup over standard decoding while preserving target distribution consistency.

Speculative decoding accelerates autoregressive sampling from LLMs by using a smaller, faster draft model to propose sequences (drafts) of tokens, which are then verified in parallel by the larger target LLM. Only a prefix of the draft sequence statistically consistent with the target model's distribution is accepted. "SpecTr: Fast Speculative Decoding via Optimal Transport" (Sun et al., 2023) introduces a novel framework for speculative decoding based on Optimal Transport (OT) with a membership cost, generalizing existing methods and leading to improved performance.

Theoretical Framework: Optimal Transport with Membership Cost

Standard speculative decoding aims to sample from a target distribution $p(x_{t+1} | x_{1:t})$ using proposals from a draft distribution $q(x'_{t+1} | x_{1:t})$. A sequence $x'_{t+1}, \dots, x'_{t+\gamma}$ is sampled from $q$. Then, for each position $i = 1, \dots, \gamma$, the token $x'_{t+i}$ is accepted if a random variable $U \sim \text{Uniform}(0, 1)$ satisfies $U \le \frac{p(x'_{t+i} | x_{1:t+i-1})}{q(x'_{t+i} | x_{1:t+i-1})}$. If rejected, a correction token is sampled from a modified distribution derived from $p$ and $q$, and the process restarts.
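
For concreteness, the single-candidate accept/reject step can be sketched in a few lines of Python. This is a minimal illustration of the standard rule above, not the paper's implementation; p_probs and q_probs are hypothetical arrays holding the target and draft distributions for the current position:

import numpy as np

def speculative_accept(p_probs, q_probs, draft_token, rng):
    # Accept the draft token with probability min(1, p(y)/q(y)).
    if rng.uniform() <= min(1.0, p_probs[draft_token] / q_probs[draft_token]):
        return draft_token
    # On rejection, emit a correction token from the residual distribution
    # proportional to max(p(x) - q(x), 0); this keeps the emitted token
    # exactly p-distributed.
    residual = np.maximum(p_probs - q_probs, 0.0)
    return rng.choice(len(p_probs), p=residual / residual.sum())

# Toy usage on a 4-token vocabulary.
rng = np.random.default_rng(0)
p = np.array([0.5, 0.2, 0.2, 0.1])
q = np.array([0.25, 0.25, 0.25, 0.25])
draft = rng.choice(4, p=q)  # draft model proposal
print(speculative_accept(p, q, draft, rng))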

SpecTr reformulates this process through the lens of optimal transport. The goal is to find a coupling (a joint probability distribution) $\pi(x, y)$ between the draft distribution $q(y)$ and the target distribution $p(x)$ (conditioned on the history $x_{1:t}$, omitted for brevity) that maximizes the expected number of accepted tokens. This is framed as an OT problem with a specific "membership cost".

The key generalization introduced by SpecTr is allowing the draft model to propose a set of $k$ candidate tokens $\{y_1, \dots, y_k\}$ at each position, instead of just one. Let $Q$ be the distribution over sets of $k$ candidates generated by the draft mechanism. The objective is to find a transport plan $\pi(x | Y)$, where $Y = \{y_1, \dots, y_k\} \sim Q$, that defines the probability of selecting the target token $x$ given the proposed set $Y$. This plan must satisfy the marginal constraint $\sum_{Y} Q(Y) \pi(x | Y) = p(x)$ for all $x$, ensuring the final output follows the target distribution $p$.

The goal is to maximize the probability that the selected token $x$ is one of the proposed candidates, i.e., maximize $\mathbb{E}_{Y \sim Q} \left[ \sum_{x \in Y} \pi(x | Y) \right]$. This corresponds to maximizing the expected number of accepted tokens from the draft set. The formulation extends the maximal-coupling problem by considering a set of candidates ($k \ge 1$) rather than just a single proposal; for $k = 1$ the optimal value reduces to the classical maximal-coupling acceptance probability $\sum_x \min(p(x), q(x))$.

Optimal Draft Selection and Computational Complexity

The paper demonstrates that the optimal transport plan $\pi^*(x | Y)$ maximizing the acceptance probability can be formulated and solved as a Linear Program (LP). The variables in the LP are $\pi(x | Y)$ for all possible target tokens $x$ and all possible candidate sets $Y$. The objective is to maximize $\sum_{Y} Q(Y) \sum_{x \in Y} \pi(x | Y)$ subject to the marginal constraint $\sum_{Y} Q(Y) \pi(x | Y) = p(x)$ for all $x$, and the probability constraints $\pi(x | Y) \ge 0$ and $\sum_x \pi(x | Y) = 1$ for all $Y$.
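
To make the LP concrete, the brute-force formulation can be solved for a toy instance with scipy.optimize.linprog. The sketch below assumes the candidate set $Y$ is an ordered tuple of $k$ i.i.d. draws from $q$ (so $Q(Y) = \prod_i q(y_i)$), which is one natural draft mechanism; the distributions are made up. Enumerating all $|V|^k$ candidate sets is exactly the blow-up that makes this approach impractical at real vocabulary sizes:

import itertools
import numpy as np
from scipy.optimize import linprog

# Toy instance: vocabulary {0, 1, 2}, k = 2 i.i.d. draft candidates.
p = np.array([0.5, 0.3, 0.2])  # hypothetical target distribution
q = np.array([0.6, 0.3, 0.1])  # hypothetical draft distribution
V, k = len(p), 2

# Candidate sets Y are ordered k-tuples drawn i.i.d. from q.
cand_sets = list(itertools.product(range(V), repeat=k))
Q = np.array([np.prod(q[list(Y)]) for Y in cand_sets])

# Decision variables: pi(x | Y), flattened as idx(iY, x).
n_vars = len(cand_sets) * V
idx = lambda iY, x: iY * V + x

# Objective: maximize sum_Y Q(Y) * sum_{x in Y} pi(x | Y); linprog minimizes.
c = np.zeros(n_vars)
for iY, Y in enumerate(cand_sets):
    for x in set(Y):  # membership cost: reward only tokens in Y
        c[idx(iY, x)] = -Q[iY]

# Equality constraints: marginal sum_Y Q(Y) pi(x | Y) = p(x) for each x,
# and normalization sum_x pi(x | Y) = 1 for each Y.
A_eq = np.zeros((V + len(cand_sets), n_vars))
b_eq = np.concatenate([p, np.ones(len(cand_sets))])
for x in range(V):
    for iY in range(len(cand_sets)):
        A_eq[x, idx(iY, x)] = Q[iY]
for iY in range(len(cand_sets)):
    for x in range(V):
        A_eq[V + iY, idx(iY, x)] = 1.0

res = linprog(c, A_eq=A_eq, b_eq=b_eq, bounds=(0, 1))
print("k=1 maximal-coupling baseline:", np.minimum(p, q).sum())
print("k=2 optimal acceptance probability:", -res.fun)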

While theoretically solvable via LP, the number of possible candidate sets $Y$ grows combinatorially with $k$ and the vocabulary size $|V|$, making the LP formulation computationally intractable for practical values of $k$ and $|V|$. The best-known runtime for solving such LPs is exponential in $k$.

The SpecTr Algorithm and Approximation Guarantee

To overcome the computational bottleneck of solving the exact LP, SpecTr introduces an efficient, greedy algorithm for constructing a valid transport plan $\pi(x | Y)$ that achieves a strong approximation guarantee. This algorithm provides a $(1 - 1/e)$-multiplicative approximation to the optimal acceptance probability achievable by any valid transport plan. Crucially, this approximate algorithm runs in time nearly linear in the vocabulary size $|V|$, specifically $\mathcal{O}(|V| \log |V|)$ or potentially $\mathcal{O}(|V|)$ with appropriate data structures, making it highly practical.

The core idea of the greedy algorithm resembles online bipartite matching or ad allocation algorithms. It iteratively assigns probability mass from the target distribution $p(x)$ to candidate sets $Y$ containing $x$, prioritizing assignments that yield the highest "gain" in acceptance probability, while carefully managing the remaining probability mass of $p(x)$ and $Q(Y)$ to ensure the marginal constraints are met.
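
The summary above does not pin down the exact greedy construction, so the following is a hedged sketch of one simple valid multi-candidate selection scheme in the same spirit: apply the standard pairwise accept/reject step to each of the $k$ candidates in turn against a shrinking residual of $p$. Each round is an exact speculative-decoding step, so the emitted token is provably $p$-distributed; for the paper's actual $(1-1/e)$-optimal algorithm, consult Sun et al. (2023):

import numpy as np

def sequential_multi_candidate_select(p_probs, q_probs, candidates, rng):
    """Try k candidates (assumed drawn i.i.d. from q) in turn against a
    shrinking unnormalized residual of the target distribution p.
    Illustrative valid scheme only, not the paper's exact algorithm."""
    residual = p_probs.astype(float)  # unnormalized residual target mass
    for y in candidates:
        mass = residual.sum()
        # Standard speculative step against the normalized residual:
        # accept y with probability min(1, r(y) / (mass * q(y))).
        if rng.uniform() <= min(1.0, residual[y] / (mass * q_probs[y])):
            return y
        # On rejection, shrink the residual exactly as in the pairwise case.
        residual = np.maximum(residual - mass * q_probs, 0.0)
    # All k candidates rejected: correction sample from the final residual.
    return rng.choice(len(p_probs), p=residual / residual.sum())

# Monte Carlo check that the output follows p.
rng = np.random.default_rng(0)
p = np.array([0.5, 0.3, 0.2])
q = np.array([0.6, 0.3, 0.1])
k = 3
draws = [sequential_multi_candidate_select(p, q, rng.choice(3, size=k, p=q), rng)
         for _ in range(100_000)]
print(np.bincount(draws, minlength=3) / len(draws))  # should be close to p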

The overall SpecTr sampling procedure at each step $t$ involves:

  1. Candidate Generation: Use the draft model $q$ to generate a set of $k$ candidate tokens $Y = \{y_1, \dots, y_k\}$ for position $t+1$. This might involve beam search or multiple sampling passes from $q$.
  2. Target Probability Calculation: Query the target model $p$ in parallel to obtain the probabilities $p(y_i | x_{1:t})$ for all candidates $y_i \in Y$. Note that obtaining the full distribution $p(x | x_{1:t})$ might be needed for the selection algorithm, depending on its specifics.
  3. Draft Selection: Apply the efficient $(1-1/e)$-approximate algorithm using $p$, $q$, and the candidate set $Y$ to determine the acceptance/rejection probabilities for the candidates. This yields a sampling distribution over $Y \cup \{\text{reject}\}$.
  4. Sampling: Sample an outcome according to the distribution computed in step 3.
    • If a token $y_i \in Y$ is sampled, it is accepted. Increment $t$ and repeat from step 1 for the next position.
    • If 'reject' is sampled, discard all candidates. Sample a correction token from the residual distribution required to maintain consistency with $p$. Restart the process from step 1 for position $t+1$.

This process allows multiple tokens (potentially the entire sequence proposed by the draft model, depending on acceptance) to be decoded per single invocation of the selection logic, amortizing the cost and leveraging the parallel computation capability for target model evaluations.

Implementation Considerations

Implementing SpecTr involves several key components and choices:

  • Draft Model Selection: The choice of the draft model $q$ is critical. It needs to be significantly faster than the target model $p$ but also provide distributions reasonably close to $p$ to ensure a high acceptance rate. Smaller versions of the target model or distilled models are common choices. The method used to generate the $k$ candidates (e.g., beam search with width $k$, diverse beam search, multiple independent samples) also impacts performance.
  • Target Model Parallelism: The primary speedup comes from evaluating the target model $p$ on multiple candidate tokens $y_i$ in parallel. Efficient batching and hardware utilization (GPUs/TPUs) are essential.
  • Parameter $k$: Selecting the number of candidates $k$ involves a trade-off. Higher $k$ increases the theoretical upper bound on the acceptance rate via the $(1-1/e)$-optimal algorithm, potentially leading to greater speedup. However, it also increases the cost of generating candidates (step 1) and potentially the complexity of the selection algorithm (step 3), though the proposed algorithm is efficient. The optimal $k$ is likely task- and model-dependent and requires empirical tuning. Values like $k=4$ or $k=8$ might be practical starting points.
  • Selection Algorithm Implementation: The efficient $(1-1/e)$-approximate algorithm needs careful implementation. It likely involves sorting or using priority queues based on probability ratios like $p(y_i)/q(y_i)$ or similar metrics derived from the OT formulation. The exact details would require referring to the algorithm description in the paper (Sun et al., 2023). Ensuring numerical stability when dealing with small probabilities is important.
  • Distribution Handling: The algorithm requires access to the probabilities $p(y_i | \dots)$ and potentially $q(y_i | \dots)$ or the distribution $Q(Y)$. Efficiently querying these distributions and handling the residual distribution for rejection sampling are necessary.
  • System Integration: Integrating SpecTr into an LLM inference framework requires modifying the generation loop. Instead of sampling one token, it generates $k$ candidates, runs the parallel verification, executes the SpecTr selection logic, and handles acceptance/rejection/correction. This involves managing the state across multiple potential future tokens.

Below is high-level conceptual pseudocode (Python-style) for the SpecTr sampling loop:

def spectr_sample(target_model_p, draft_model_q, k, prompt):
    tokens = list(prompt)
    finished = False
    while not finished:
        # 1. Generate k candidate tokens for the next position with the
        #    draft model q (e.g., beam search or k independent samples).
        candidates = generate_candidates(draft_model_q, tokens, k)

        # 2. Evaluate target probabilities for the candidates in parallel:
        #    {y_i: p(y_i | tokens) for y_i in candidates}. Depending on the
        #    selection algorithm, the draft probabilities q(y_i | tokens) or
        #    the full distribution p(. | tokens) may also be needed.
        target_probs = target_model_p.get_probs(tokens, candidates)

        # 3. Run the SpecTr selection algorithm, which computes a sampling
        #    distribution over candidates plus a 'reject' outcome.
        sampling_dist = spectr_selection_algorithm(target_probs, candidates)

        # 4. Sample an outcome from that distribution.
        outcome = sample_from_distribution(sampling_dist)

        if outcome == REJECT:
            # All candidates discarded: sample a correction token from the
            # residual distribution so the overall output still follows p.
            next_token = sample_correction_token(
                target_model_p, draft_model_q, tokens, candidates)
        else:
            # outcome is an accepted candidate y_i from the draft set.
            next_token = outcome
        tokens.append(next_token)

        # Check termination (e.g., max length or EOS token).
        finished = should_terminate(tokens)
    return tokens

Experimental Results

The paper reports significant empirical speedups. On standard benchmarks using state-of-the-art LLMs, SpecTr achieved a wall clock speedup of 2.13X compared to standard autoregressive decoding. Furthermore, it demonstrated a 1.37X speedup over the baseline speculative decoding method (which corresponds to the case $k=1$). These results highlight the practical benefit of the OT formulation and the multi-candidate ($k>1$) approach enabled by SpecTr. The authors emphasize that this speedup is achieved while mathematically guaranteeing that the output distribution remains identical to that of the target model $p$, thus incurring no degradation in generation quality.

Conclusion

SpecTr provides a principled extension to speculative decoding by framing it as an optimal transport problem with membership cost. This allows generalizing the mechanism to handle multiple ($k$) draft candidates per position. While the exact optimal solution is computationally expensive, SpecTr introduces an efficient approximation algorithm with a $(1-1/e)$-optimality guarantee and near-linear time complexity. This leads to substantial wall-clock speedups in practice (reported as 2.13X overall, 1.37X over standard speculative decoding) without compromising output quality, making it a promising technique for accelerating inference in LLMs. Implementing SpecTr requires careful integration with existing inference systems, focusing on efficient candidate generation, parallel target model evaluation, and the implementation of the novel selection algorithm.
