Faster Diffusion Sampling with Randomized Midpoints: Sequential and Parallel (2406.00924v2)
Abstract: Sampling algorithms play an important role in controlling the quality and runtime of diffusion model inference. In recent years, a number of works~\cite{chen2023sampling,chen2023ode,benton2023error,lee2022convergence} have proposed schemes for diffusion sampling with provable guarantees; these works show that for essentially any data distribution, one can approximately sample in polynomial time given a sufficiently accurate estimate of its score functions at different noise levels. In this work, we propose a new scheme inspired by Shen and Lee's randomized midpoint method for log-concave sampling~\cite{ShenL19}. We prove that this approach achieves the best known dimension dependence for sampling from arbitrary smooth distributions in total variation distance ($\widetilde O(d^{5/12})$ compared to $\widetilde O(\sqrt{d})$ from prior work). We also show that our algorithm can be parallelized to run in only $\widetilde O(\log^2 d)$ parallel rounds, constituting the first provable guarantees for parallel sampling with diffusion models. As a byproduct of our methods, for the well-studied problem of log-concave sampling in total variation distance, we give an algorithm and simple analysis achieving dimension dependence $\widetilde O(d^{5/12})$ compared to $\widetilde O(\sqrt{d})$ from prior work.