
Consistency Models

Published 2 Mar 2023 in cs.LG, cs.CV, and stat.ML | (2303.01469v2)

Abstract: Diffusion models have significantly advanced the fields of image, audio, and video generation, but they depend on an iterative sampling process that causes slow generation. To overcome this limitation, we propose consistency models, a new family of models that generate high quality samples by directly mapping noise to data. They support fast one-step generation by design, while still allowing multistep sampling to trade compute for sample quality. They also support zero-shot data editing, such as image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either by distilling pre-trained diffusion models, or as standalone generative models altogether. Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step sampling, achieving the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. When trained in isolation, consistency models become a new family of generative models that can outperform existing one-step, non-adversarial generative models on standard benchmarks such as CIFAR-10, ImageNet 64x64 and LSUN 256x256.

Citations (645)

Summary

  • The paper introduces Consistency Models, which bypass iterative diffusion sampling by learning a function that maps any point on the PF ODE trajectory directly back to the original data.
  • It details two training methods—Consistency Distillation and Consistency Training—that enforce self-consistency using parameterized functions and EMA-based target networks.
  • Practical applications include fast one-step generation and versatile zero-shot image editing, with state-of-the-art one-step FID scores among diffusion distillation methods on CIFAR-10 and ImageNet 64x64.

This paper introduces Consistency Models (CMs) (2303.01469), a new class of generative models designed to address the slow iterative sampling process inherent in diffusion models while retaining many of their benefits. The core idea is to learn a function that directly maps points from any time step on a Probability Flow (PF) ODE trajectory back to the trajectory's origin (the data sample).

Core Concept: Consistency Function

  1. PF ODE: Diffusion models rely on a PF ODE (Eq. 2) that transforms data $x_0$ into noise $x_T$. The reverse process generates data from noise $x_T$ by solving the ODE backward in time.
  2. Consistency Function: Defined as $f: (x_t, t) \mapsto x_\epsilon$, where $x_t$ is a point on the ODE trajectory at time $t$, and $x_\epsilon$ is the point near the trajectory's origin (the data point, taken at a small time $\epsilon > 0$).
  3. Self-Consistency Property: The defining characteristic is that for any two points $(x_t, t)$ and $(x_{t'}, t')$ on the same ODE trajectory, the consistency function output is identical: $f(x_t, t) = f(x_{t'}, t')$ for all $t, t' \in [\epsilon, T]$.
  4. Consistency Model: A parameterized function $f_\theta$ is trained to approximate the true consistency function $f$ by enforcing this self-consistency property.
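To make self-consistency concrete, consider a toy case where the consistency function has a closed form. This is an illustrative assumption, not an example from the paper. Take 1-D data distributed as $\mathcal{N}(0, \sigma_d^2)$, the PF ODE $dx_t/dt = -t\,\nabla \log p_t(x_t)$, and $p_t = \mathcal{N}(0, \sigma_d^2 + t^2)$. The ODE then reduces to $dx_t/dt = t\,x_t / (\sigma_d^2 + t^2)$, so every trajectory scales like $\sqrt{\sigma_d^2 + t^2}$. This gives $f(x_t, t) = x_t \sqrt{\sigma_d^2 + \epsilon^2} / \sqrt{\sigma_d^2 + t^2}$. The sketch integrates one trajectory numerically and checks that $f$ returns the same point from two different times. `SIGMA_D`, `EPS`, and `T` are assumed values.

```python
import numpy as np

SIGMA_D = 0.5          # std of the toy Gaussian data (assumption)
EPS, T = 0.002, 80.0   # smallest and largest time (assumed EDM-style values)

def ode_drift(x, t):
    # PF ODE dx/dt = -t * score(x, t); for N(0, SIGMA_D^2) data the score of
    # p_t = N(0, SIGMA_D^2 + t^2) is -x / (SIGMA_D^2 + t^2).
    return t * x / (SIGMA_D**2 + t**2)

def consistency_fn(x_t, t):
    # Closed-form f(x_t, t) = x_eps: trajectories scale like sqrt(SIGMA_D^2 + t^2).
    return x_t * np.sqrt(SIGMA_D**2 + EPS**2) / np.sqrt(SIGMA_D**2 + t**2)

def integrate(x, t_from, t_to, steps=10_000):
    # Classic fourth-order Runge-Kutta on a uniform time grid.
    ts = np.linspace(t_from, t_to, steps + 1)
    for t0, t1 in zip(ts[:-1], ts[1:]):
        h = t1 - t0
        k1 = ode_drift(x, t0)
        k2 = ode_drift(x + 0.5 * h * k1, t0 + 0.5 * h)
        k3 = ode_drift(x + 0.5 * h * k2, t0 + 0.5 * h)
        k4 = ode_drift(x + h * k3, t1)
        x = x + h * (k1 + 2 * k2 + 2 * k3 + k4) / 6
    return x

x_T = np.random.default_rng(0).normal(0.0, T, size=4)  # noise at time T
x_mid = integrate(x_T, T, 1.0)                           # same trajectory at t = 1
print(consistency_fn(x_T, T))
print(consistency_fn(x_mid, 1.0))  # agrees with the line above
```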

Implementation: Parameterization

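A crucial aspect is enforcing the boundary condition $f_\theta(x, \epsilon) = x$. The paper proposes and uses a practical parameterization based on skip connections: $f_\theta(x, t) = c_{\text{skip}}(t)\, x + c_{\text{out}}(t)\, F_\theta(x, t)$, where $F_\theta$ is a neural network (e.g., based on diffusion model architectures like U-Net), and $c_{\text{skip}}(t)$, $c_{\text{out}}(t)$ are differentiable functions satisfying:

  • $c_{\text{skip}}(\epsilon) = 1$
  • $c_{\text{out}}(\epsilon) = 0$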

This structure guarantees the boundary condition and lets the model reuse existing diffusion model architectures. The paper uses modified versions of the EDM scaling factors (Karras et al., 2022) so that these conditions hold at $t = \epsilon > 0$ rather than at $t = 0$.
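The sketch below shows the skip-connection parameterization. The scaling functions are assumed to take the modified EDM form $c_{\text{skip}}(t) = \sigma_{\text{data}}^2 / ((t - \epsilon)^2 + \sigma_{\text{data}}^2)$ and $c_{\text{out}}(t) = \sigma_{\text{data}}(t - \epsilon) / \sqrt{\sigma_{\text{data}}^2 + t^2}$. These are reproduced from memory of the paper's appendix, so treat them as an assumption. The values $\sigma_{\text{data}} = 0.5$ and $\epsilon = 0.002$ are also assumed, and `F_theta` is a hypothetical stand-in for the backbone network.

```python
import numpy as np

SIGMA_DATA, EPS = 0.5, 0.002  # assumed data std and minimum time

def c_skip(t):
    # Skip scaling: exactly 1 at t = EPS (assumed modified-EDM form).
    return SIGMA_DATA**2 / ((t - EPS) ** 2 + SIGMA_DATA**2)

def c_out(t):
    # Output scaling: exactly 0 at t = EPS (assumed modified-EDM form).
    return SIGMA_DATA * (t - EPS) / np.sqrt(SIGMA_DATA**2 + t**2)

def F_theta(x, t):
    # Hypothetical stand-in for the U-Net; any finite-valued function works here.
    return np.tanh(x) * np.log1p(t)

def f_theta(x, t):
    # Consistency model = skip connection + scaled network output.
    return c_skip(t) * x + c_out(t) * F_theta(x, t)

x = np.array([0.3, -1.2, 2.0])
print(f_theta(x, EPS))   # equals x, so f_theta(x, EPS) = x holds by construction
print(f_theta(x, 10.0))  # away from EPS the network term dominates
```

Because the boundary condition comes from the parameterization itself, training never has to learn it, whatever `F_theta` outputs.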

Implementation: Sampling

  • One-Step Generation: Sample noise $\hat{x}_T \sim \mathcal{N}(0, T^2 I)$ and compute the data sample directly: $\hat{x}_\epsilon = f_\theta(\hat{x}_T, T)$. This is very fast, requiring only one network evaluation.
  • Multi-Step Sampling (Algorithm 1): Improves sample quality at the cost of extra compute by alternating noise injection and denoising with the CM:
    • Start from the one-step sample $x = f_\theta(\hat{x}_T, T)$.
    • For $n = 1, \dots, N-1$, take the time $\tau_n$ from a predefined decreasing sequence $\tau_1 > \tau_2 > \dots > \tau_{N-1}$.
    • Add noise: sample $z \sim \mathcal{N}(0, I)$ and compute $\hat{x}_{\tau_n} = x + \sqrt{\tau_n^2 - \epsilon^2}\, z$.
    • Denoise: compute $x \leftarrow f_\theta(\hat{x}_{\tau_n}, \tau_n)$.
    • After the last iteration, output $x$.
    • The sequence $\{\tau_n\}$ can be found with optimization methods such as greedy ternary search to minimize FID.
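Both sampling procedures fit in a few lines of NumPy. The function follows Algorithm 1 as summarized above. `toy_f` is a hypothetical stand-in for a trained model: the exact consistency function of 1-D Gaussian data with standard deviation 0.5. `EPS = 0.002` and `T = 80` are assumed EDM-style constants.

```python
import numpy as np

EPS, T = 0.002, 80.0  # assumed EDM-style constants

def multistep_sample(f, taus, shape, rng):
    """Multistep consistency sampling sketch.

    f     -- consistency model f(x, t) returning an estimate of x_eps
    taus  -- decreasing intermediate times tau_1 > ... > tau_{N-1} ([] = one step)
    """
    x_T = rng.normal(0.0, T, size=shape)          # x_T ~ N(0, T^2 I)
    x = f(x_T, T)                                  # one-step sample
    for tau in taus:
        z = rng.standard_normal(shape)
        x_tau = x + np.sqrt(tau**2 - EPS**2) * z   # re-noise to time tau
        x = f(x_tau, tau)                          # denoise again
    return x

def toy_f(x, t, sigma_d=0.5):
    # Hypothetical stand-in: exact consistency function for N(0, sigma_d^2) data.
    return x * np.sqrt(sigma_d**2 + EPS**2) / np.sqrt(sigma_d**2 + t**2)

rng = np.random.default_rng(0)
one_step = multistep_sample(toy_f, [], (100_000,), rng)
two_step = multistep_sample(toy_f, [0.8], (100_000,), rng)
print(one_step.std(), two_step.std())  # both close to the data std of 0.5
```

With an exact $f$, extra steps cannot help, so both sample sets already match the data spread. With a learned $f_\theta$, each re-noise/denoise round intuitively gives the model another chance to correct its own errors. That is the compute-for-quality trade described above.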

Training Method 1: Consistency Distillation (CD)

This method trains a CM $f_\theta$ by distilling knowledge from a pre-trained diffusion (score) model $s_\phi(x, t)$.

  1. Goal: Enforce $f_\theta(x_{t_{n+1}}, t_{n+1}) \approx f_{\theta^-}(\hat{x}^\phi_{t_n}, t_n)$ for adjacent points on the empirical PF ODE trajectory defined by $s_\phi$, using a time discretization $\epsilon = t_1 < t_2 < \dots < t_N = T$.
  2. Process (Algorithm 2):
    • Sample data $x \sim p_{\text{data}}$.
    • Sample a time index $n \sim \mathcal{U}\{1, \dots, N-1\}$.
    • Generate a noisy sample $x_{t_{n+1}} \sim \mathcal{N}(x, t_{n+1}^2 I)$.
    • Use one step of a numerical ODE solver (e.g., Heun) with the score model $s_\phi$ to estimate the previous point: $\hat{x}^\phi_{t_n} = x_{t_{n+1}} + (t_n - t_{n+1})\, \Phi(x_{t_{n+1}}, t_{n+1}; \phi)$, where $\Phi$ is the solver's update function.
    • Minimize the consistency distillation loss (Eq. 7):

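      $\mathcal{L}^N_{\text{CD}}(\theta, \theta^-; \phi) = \mathbb{E}\left[\lambda(t_n)\, d\big(f_\theta(x_{t_{n+1}}, t_{n+1}),\, f_{\theta^-}(\hat{x}^\phi_{t_n}, t_n)\big)\right]$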

  3. Implementation Details:
    • $\theta^-$ is a target network, updated as an Exponential Moving Average (EMA) of $\theta$ (Eq. 8): $\theta^- \leftarrow \text{stopgrad}(\mu \theta^- + (1 - \mu)\theta)$. Using stop_gradient on the target network output is crucial for stability.
    • $d(\cdot, \cdot)$ is a distance metric. LPIPS (Zhang et al., 2018) works best for images, outperforming L1 and L2.
    • $\lambda(\cdot)$ is a weighting function (often set to 1).
    • Higher-order ODE solvers (like Heun) generally perform better than lower-order ones (like Euler) for computing $\hat{x}^\phi_{t_n}$.
    • The number of discretization steps $N$ needs tuning (e.g., $N = 18$ for CIFAR-10 with Heun).
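Here is a minimal sketch of one CD loss evaluation. It assumes 1-D toy data, $\lambda \equiv 1$, and squared L2 in place of LPIPS, since LPIPS needs a pretrained image network. The time grid uses the EDM-style discretization with assumed constants $\rho = 7$, $\epsilon = 0.002$, $T = 80$. `exact_score` and `exact_f` are hypothetical stand-ins for $s_\phi$ and $f_\theta$ that are exact for $\mathcal{N}(0, 0.25)$ data. The exact model should therefore score a loss near zero, while a model that ignores $t$ is heavily penalized.

```python
import numpy as np

# Assumed EDM-style constants; SIGMA_D is the std of the toy data.
EPS, T, RHO, SIGMA_D = 0.002, 80.0, 7.0, 0.5

def karras_times(N):
    # Discretization eps = t[0] < t[1] < ... < t[N-1] = T (0-indexed here).
    i = np.arange(N)
    lo, hi = EPS ** (1 / RHO), T ** (1 / RHO)
    return (lo + i / (N - 1) * (hi - lo)) ** RHO

def heun_step(x, t_next, t_cur, score):
    # One Heun step of dx/dt = -t * score(x, t), integrating from t_next down to t_cur.
    d = -t_next * score(x, t_next)
    x_euler = x + (t_cur - t_next) * d
    d_prime = -t_cur * score(x_euler, t_cur)
    return x + (t_cur - t_next) * 0.5 * (d + d_prime)

def cd_loss(f_online, f_target, score, x0, N, rng):
    # Monte Carlo CD loss for a 1-D batch x0, lambda = 1, squared L2 instead of LPIPS.
    t = karras_times(N)
    n = rng.integers(0, N - 1, size=x0.shape[0])          # adjacent pair (t[n], t[n+1])
    t_cur, t_next = t[n], t[n + 1]
    x_next = x0 + t_next * rng.standard_normal(x0.shape)   # x_{t_{n+1}} ~ N(x0, t^2)
    x_hat = heun_step(x_next, t_next, t_cur, score)        # solver estimate of x_{t_n}
    return np.mean((f_online(x_next, t_next) - f_target(x_hat, t_cur)) ** 2)

# Hypothetical stand-ins, exact for N(0, SIGMA_D^2) data.
def exact_score(x, t):
    return -x / (SIGMA_D**2 + t**2)

def exact_f(x, t):
    return x * np.sqrt(SIGMA_D**2 + EPS**2) / np.sqrt(SIGMA_D**2 + t**2)

rng = np.random.default_rng(0)
x0 = rng.normal(0.0, SIGMA_D, size=50_000)
exact_loss = cd_loss(exact_f, exact_f, exact_score, x0, 18, rng)
ident_loss = cd_loss(lambda x, t: x, lambda x, t: x, exact_score, x0, 18, rng)
print(exact_loss, ident_loss)  # the first is far smaller than the second
```

In actual training, `f_online` and `f_target` would be the same network evaluated with parameters $\theta$ and $\theta^-$. Only the online branch would receive gradients.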

Training Method 2: Consistency Training (CT)

This method trains a CM $f_\theta$ from scratch, without requiring a pre-trained diffusion model. It makes CMs an independent class of generative models.

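  1. Goal: Enforce $f_\theta(x + t_{n+1} z, t_{n+1}) \approx f_{\theta^-}(x + t_n z, t_n)$, where $z \sim \mathcal{N}(0, I)$ and the same $z$ is used in both terms.
  2. Process (Algorithm 3): Based on Theorem 2, which shows that with the Euler solver and the score replaced by its unbiased estimator $-(x_t - x)/t^2$, the CD loss matches the CT loss (Eq. 9) up to terms that vanish as the step size goes to zero.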
    • Sample data $x \sim p_{\text{data}}$.
    • Sample a time index $n \sim \mathcal{U}\{1, \dots, N(k)-1\}$ (where $N(k)$ increases during training).
    • Sample noise $z \sim \mathcal{N}(0, I)$.
    • Minimize the consistency training loss:

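      $\mathcal{L}^N_{\text{CT}}(\theta, \theta^-) = \mathbb{E}\left[\lambda(t_n)\, d\big(f_\theta(x + t_{n+1} z, t_{n+1}),\, f_{\theta^-}(x + t_n z, t_n)\big)\right]$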

  3. Implementation Details:
    • Uses the same EMA target network $\theta^-$ as CD.
    • Crucially uses adaptive schedules for the number of time steps $N(k)$ and the EMA decay rate $\mu(k)$ (where $k$ is the training step). $N(k)$ typically starts small and increases, while $\mu(k)$ starts high (e.g., 0.9) and approaches 1. This balances convergence speed and final quality. Appendix C provides specific schedule formulas.
    • LPIPS is also effective here.
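CT differs from CD only in how the second input is formed. Both inputs reuse the same noise $z$, so no score model or ODE solver step is needed. The sketch assumes 1-D toy data, $\lambda \equiv 1$, squared L2 in place of LPIPS, and an EDM-style grid with assumed constants $\rho = 7$, $\epsilon = 0.002$, $T = 80$. `exact_f` is a hypothetical stand-in for $f_\theta$. At finite $N$, even the exact consistency function gives a nonzero loss here, so the check only compares it against a model that ignores $t$.

```python
import numpy as np

EPS, T, RHO, SIGMA_D = 0.002, 80.0, 7.0, 0.5  # assumed constants

def karras_times(N):
    # EDM-style grid eps = t[0] < ... < t[N-1] = T (0-indexed).
    i = np.arange(N)
    lo, hi = EPS ** (1 / RHO), T ** (1 / RHO)
    return (lo + i / (N - 1) * (hi - lo)) ** RHO

def ct_loss(f_online, f_target, x0, N, rng):
    # CT loss for a 1-D batch x0, lambda = 1, squared L2 instead of LPIPS.
    # No score model: both inputs share the same noise z.
    t = karras_times(N)
    n = rng.integers(0, N - 1, size=x0.shape[0])
    z = rng.standard_normal(x0.shape)
    online = f_online(x0 + t[n + 1] * z, t[n + 1])
    target = f_target(x0 + t[n] * z, t[n])   # EMA target; not differentiated in training
    return np.mean((online - target) ** 2)

def exact_f(x, t):
    # Hypothetical stand-in: exact consistency function for N(0, SIGMA_D^2) data.
    return x * np.sqrt(SIGMA_D**2 + EPS**2) / np.sqrt(SIGMA_D**2 + t**2)

rng = np.random.default_rng(1)
x0 = rng.normal(0.0, SIGMA_D, size=50_000)
exact_ct = ct_loss(exact_f, exact_f, x0, 18, rng)
ident_ct = ct_loss(lambda x, t: x, lambda x, t: x, x0, 18, rng)
print(exact_ct, ident_ct)
```

During training, the `N` argument would be the current value of the growing schedule $N(k)$ rather than a constant.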

Practical Applications & Results

  • Fast Generation: CMs achieve state-of-the-art FID scores for one-step and two-step generation on CIFAR-10 (3.55/2.93 FID) and ImageNet 64x64 (6.20/4.70 FID) when trained via CD, significantly outperforming Progressive Distillation (PD).
  • Standalone Performance: When trained via CT, CMs outperform other one-step non-adversarial methods (VAEs, Flows) and achieve results comparable to PD without needing distillation.
  • Zero-Shot Data Editing: CMs inherit the editing capabilities of diffusion models. Using variations of the multi-step sampling algorithm (Algorithm 4 in Appendix), they can perform:
    • Inpainting: Mask unknown regions and iteratively refine using the CM.
    • Colorization: Treat color channels as missing information in a transformed space (e.g., YUV or using an orthogonal basis).
    • Super-Resolution: Treat high-frequency details as missing information in a transformed space (e.g., using patch averaging and orthogonal basis).
    • Stroke-guided Editing (SDEdit): Use a stroke image as the starting point of multi-step sampling.
    • Denoising: Apply $f_\theta(x_t, t)$ directly to an image $x_t$ corrupted with Gaussian noise of standard deviation $t$.
    • Interpolation: Interpolate between two initial noise vectors (e.g., using spherical linear interpolation) and then apply $f_\theta(\cdot, T)$ to the interpolated noise.
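Inpainting is the simplest of these editing tasks to sketch. The code below is reconstructed from the paper's Algorithm 4 as summarized here. It takes the linear transform that algorithm uses (colorization and super-resolution would need a different one) to be the identity, which gives pixel-space inpainting. After every denoising step, the known pixels are copied back from the reference image. `toy_f` is a hypothetical per-pixel stand-in for a trained model, so the filled region is not content-aware. The sketch only shows the control flow. `EPS` and `T` are assumed constants.

```python
import numpy as np

EPS, T = 0.002, 80.0  # assumed EDM-style constants

def inpaint(f, y, mask, taus, rng):
    """Pixel-space inpainting sketch (identity transform assumed).

    f     -- consistency model f(x, t)
    y     -- reference image; values where mask == 1 are ignored
    mask  -- 1 for unknown pixels to fill in, 0 for pixels kept from y
    taus  -- decreasing times tau_1 > ... > tau_{N-1}
    """
    keep = 1.0 - mask
    y = y * keep                                    # discard the unknown region
    x = f(y + T * rng.standard_normal(y.shape), T)  # x ~ N(y, T^2 I), then denoise
    x = y * keep + x * mask                         # restore known pixels
    for tau in taus:
        x = x + np.sqrt(tau**2 - EPS**2) * rng.standard_normal(y.shape)
        x = f(x, tau)
        x = y * keep + x * mask
    return x

def toy_f(x, t, sigma_d=0.5):
    # Hypothetical stand-in for a trained model (exact for i.i.d. N(0, 0.25) pixels).
    return x * np.sqrt(sigma_d**2 + EPS**2) / np.sqrt(sigma_d**2 + t**2)

rng = np.random.default_rng(0)
y = rng.uniform(-1.0, 1.0, size=(8, 8))
mask = np.zeros((8, 8))
mask[2:6, 2:6] = 1.0                                # hole in the middle
out = inpaint(toy_f, y, mask, [2.0, 0.5], rng)
```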

Implementation Considerations

  • Architecture: Can reuse U-Net architectures from diffusion models (e.g., NCSN++, ADM).
  • Target Network: Using an EMA target network with stop_gradient is vital for both CD and CT.
  • Metric: LPIPS is highly recommended for image data.
  • Schedules (CT): Carefully designed adaptive schedules for $N(k)$ and $\mu(k)$ are important for CT performance.
  • Computational Cost: Training cost is comparable to training diffusion models. Inference is much faster (1 network evaluation for one-step, N evaluations for N-step).
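The target-network update and the CT schedules can be written as plain functions. The schedule formulas below are reconstructed from the paper's Appendix C and should be treated as assumptions. The defaults (`s0 = 2`, `s1 = 150`, `mu0 = 0.9`, `K = 800_000` training steps) are illustrative only. Parameters are kept in a plain dict with a hypothetical key name. `ema_update` would run after each optimizer step and is never differentiated through.

```python
import math

def n_schedule(k, K, s0=2, s1=150):
    # Assumed form of N(k): grows from s0 at k = 0 to s1 + 1 at k = K.
    return math.ceil(math.sqrt(k / K * ((s1 + 1) ** 2 - s0**2) + s0**2) - 1) + 1

def mu_schedule(k, K, s0=2, s1=150, mu0=0.9):
    # Assumed form of mu(k) = exp(s0 * log(mu0) / N(k)): mu0 at k = 0, rising toward 1.
    return math.exp(s0 * math.log(mu0) / n_schedule(k, K, s0, s1))

def ema_update(target, online, mu):
    # theta^- <- mu * theta^- + (1 - mu) * theta, parameter by parameter.
    return {name: mu * target[name] + (1 - mu) * online[name] for name in target}

K = 800_000
for k in (0, K // 4, K):
    print(k, n_schedule(k, K), round(mu_schedule(k, K), 5))

params = {"w": 0.0}            # hypothetical online parameters after a step
target = {"w": 1.0}            # hypothetical target parameters
target = ema_update(target, params, mu_schedule(0, K))
```

Tying $\mu$ to $N$ makes the target move slowly once the discretization is fine, when adjacent time steps are close and the training signal per step is small.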

Continuous-Time Extensions

The paper also derives continuous-time versions of the CD and CT losses (Appendix B), eliminating the need for a fixed discretization $t_1, \dots, t_N$. These objectives require calculating Jacobian-vector products, often necessitating forward-mode automatic differentiation, which might not be standard in all frameworks. Experimental results show they can work well, especially continuous-time CT, but may require careful initialization or variance reduction techniques.
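The quantity these objectives need is the rate of change of $f$ along the ODE, $\frac{d}{dt} f(x_t, t) = \frac{\partial f}{\partial x}\frac{dx_t}{dt} + \frac{\partial f}{\partial t}$. This is a Jacobian-vector product along the direction $(dx_t/dt, 1)$, which forward-mode differentiation computes in a single pass. The toy sketch uses hand-rolled dual numbers rather than any framework's API. Frameworks expose the same operation as, e.g., `jax.jvp` or `torch.func.jvp`. It reuses the illustrative 1-D Gaussian example (data std 0.5, an assumption). For the exact consistency function, the derivative vanishes, which is the continuous-time form of self-consistency.

```python
import math

class Dual:
    """Minimal forward-mode AD number: a value plus one tangent component."""
    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    @staticmethod
    def lift(o):
        return o if isinstance(o, Dual) else Dual(o)

    def __add__(self, o):
        o = Dual.lift(o)
        return Dual(self.val + o.val, self.tan + o.tan)
    __radd__ = __add__

    def __mul__(self, o):
        o = Dual.lift(o)
        return Dual(self.val * o.val, self.val * o.tan + self.tan * o.val)
    __rmul__ = __mul__

    def __truediv__(self, o):
        o = Dual.lift(o)
        return Dual(self.val / o.val, (self.tan * o.val - self.val * o.tan) / o.val**2)

def dsqrt(d):
    # sqrt for dual numbers: d/dx sqrt(x) = 1 / (2 sqrt(x)).
    r = math.sqrt(d.val)
    return Dual(r, d.tan / (2 * r))

SIGMA_D, EPS = 0.5, 0.002  # assumed toy constants

def f(x, t):
    # Exact consistency function of the 1-D Gaussian toy (stand-in for f_theta).
    return x * math.sqrt(SIGMA_D**2 + EPS**2) / dsqrt(t * t + SIGMA_D**2)

def total_derivative(fn, x, t):
    # JVP of fn at (x, t) along (dx/dt, 1), i.e. d/dt fn(x_t, t) along the PF ODE.
    dxdt = t * x / (SIGMA_D**2 + t**2)
    return fn(Dual(x, dxdt), Dual(t, 1.0)).tan

print(total_derivative(f, 1.3, 2.0))               # ~0: f is constant along the trajectory
print(total_derivative(lambda x, t: x, 1.3, 2.0))  # nonzero for a model that ignores t
```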
