- The paper introduces Consistency Models, which bypass iterative diffusion sampling by learning a function that maps any point on the PF ODE trajectory directly back to the original data.
- It details two training methods, Consistency Distillation (CD) and Consistency Training (CT), which enforce self-consistency between the model and an EMA-updated target network.
- Practical applications include fast one-step generation and zero-shot image editing; the models achieve state-of-the-art one-step FID scores on CIFAR-10 and ImageNet 64x64.
This paper introduces Consistency Models (CMs) (2303.01469), a new class of generative models designed to address the slow iterative sampling process inherent in diffusion models while retaining many of their benefits. The core idea is to learn a function that directly maps points from any time step on a Probability Flow (PF) ODE trajectory back to the trajectory's origin (the data sample).
Core Concept: Consistency Function
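- PF ODE: Diffusion models rely on a PF ODE (Eq. 2) that transforms data $x_0$ into noise $x_T$. The reverse process generates data from noise $x_T$ by solving the ODE backward. In the setting the paper adopts from Karras et al. (2022), the PF ODE reads $\frac{\mathrm{d}x_t}{\mathrm{d}t} = -t\,\nabla_x \log p_t(x_t)$, and replacing the score with a learned score model $s_\phi(x, t)$ gives the empirical PF ODE.
- Consistency Function: Defined as $f: (x_t, t) \mapsto x_\epsilon$, where $x_t$ is a point on the ODE trajectory at time $t$, and $x_\epsilon$ is the trajectory's point at a small fixed time $\epsilon > 0$, which approximates the data sample (the ODE is stopped at $\epsilon$ rather than 0 to avoid numerical issues).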
- Self-Consistency Property: The defining characteristic is that for any two points $(x_t, t)$ and $(x_{t'}, t')$ on the same ODE trajectory, the consistency function output is identical: $f(x_t, t) = f(x_{t'}, t')$ for all $t, t' \in [\epsilon, T]$.
- Consistency Model: A parameterized function $f_\theta$ is trained to approximate the true consistency function $f$ by enforcing this self-consistency property.
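To make these definitions concrete, here is a toy case that is not from the paper: for 1-D Gaussian data $x_0 \sim \mathcal{N}(M, S^2)$ noised as $x_t = x_0 + t z$, every marginal is $p_t = \mathcal{N}(M, S^2 + t^2)$, and the PF ODE $\mathrm{d}x_t/\mathrm{d}t = -t\,\nabla\log p_t(x_t)$ keeps $(x_t - M)/\sqrt{S^2 + t^2}$ constant. The consistency function therefore has the closed form $f(x_t, t) = M + (x_t - M)\sqrt{S^2 + \epsilon^2}/\sqrt{S^2 + t^2}$. The constants and helper names below are hypothetical; the sketch checks self-consistency and cross-checks the closed form against a plain numerical ODE solve.

```python
import math

# Hypothetical toy setting (not from the paper): 1-D data x0 ~ N(M, S^2),
# noised as x_t = x0 + t * z, on the time horizon [EPS, T].
M, S = 1.0, 0.5
EPS, T = 0.002, 80.0


def pf_ode_velocity(x, t):
    """dx/dt = -t * score(x, t); here p_t = N(M, S^2 + t^2) is known exactly."""
    score = -(x - M) / (S**2 + t**2)
    return -t * score


def trajectory_point(x_t, t, t_new):
    """Exact ODE solution: (x - M) / sqrt(S^2 + t^2) stays constant in t."""
    return M + (x_t - M) * math.sqrt(S**2 + t_new**2) / math.sqrt(S**2 + t**2)


def consistency_fn(x_t, t):
    """True consistency function: map (x_t, t) to the trajectory's point at EPS."""
    return trajectory_point(x_t, t, EPS)


def euler_solve(x, t_start, t_end, steps=20_000):
    """Numerically integrate the PF ODE (only used to cross-check the closed form)."""
    h = (t_end - t_start) / steps
    t = t_start
    for _ in range(steps):
        x = x + h * pf_ode_velocity(x, t)
        t += h
    return x


x_T = 37.0                             # a point at time T
x_mid = trajectory_point(x_T, T, 3.0)  # the same trajectory at t = 3
print(consistency_fn(x_T, T), consistency_fn(x_mid, 3.0))  # equal up to rounding
print(abs(euler_solve(x_T, T, 3.0) - x_mid))               # small integration error
```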
Implementation: Parameterization
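A crucial aspect is enforcing the boundary condition $f_\theta(x_\epsilon, \epsilon) = x_\epsilon$. The paper proposes and uses a practical parameterization using skip connections:
$$f_\theta(x, t) = c_{\text{skip}}(t)\, x + c_{\text{out}}(t)\, F_\theta(x, t)$$
where $F_\theta$ is a neural network (e.g., based on diffusion model architectures like U-Net), and $c_{\text{skip}}(t)$, $c_{\text{out}}(t)$ are differentiable functions satisfying $c_{\text{skip}}(\epsilon) = 1$ and $c_{\text{out}}(\epsilon) = 0$.
This structure ensures the boundary condition is met and allows leveraging existing diffusion model architectures. The paper uses modified versions of the scaling factors from EDM (Karras et al., 2022) so that these conditions hold at $t = \epsilon > 0$ rather than at $t = 0$.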
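The skip-connection parameterization can be sketched as follows. The scaling functions are the EDM-derived forms as we recall them from the paper, $c_{\text{skip}}(t) = \sigma_{\text{data}}^2 / ((t-\epsilon)^2 + \sigma_{\text{data}}^2)$ and $c_{\text{out}}(t) = \sigma_{\text{data}}(t-\epsilon)/\sqrt{\sigma_{\text{data}}^2 + t^2}$ with $\sigma_{\text{data}} = 0.5$; check them against the paper before relying on them. `toy_backbone` is a hypothetical stand-in for the U-Net $F_\theta$.

```python
import math

# Illustrative constants (assumed EDM-style defaults): sigma_data = 0.5, eps = 0.002.
SIGMA_DATA = 0.5
EPS = 0.002


def c_skip(t):
    # Equals 1 at t = EPS, so the input passes through unchanged there.
    return SIGMA_DATA**2 / ((t - EPS) ** 2 + SIGMA_DATA**2)


def c_out(t):
    # Equals 0 at t = EPS, so the network output is switched off there.
    return SIGMA_DATA * (t - EPS) / math.sqrt(SIGMA_DATA**2 + t**2)


def make_consistency_model(F):
    """Wrap a backbone F(x, t) as f(x, t) = c_skip(t) * x + c_out(t) * F(x, t)."""
    def f(x, t):
        return c_skip(t) * x + c_out(t) * F(x, t)
    return f


def toy_backbone(x, t):
    # Hypothetical stand-in for F_theta; any function of (x, t) works here.
    return math.tanh(x) * math.log1p(t)


f_theta = make_consistency_model(toy_backbone)
print(f_theta(0.3, EPS))   # 0.3: the boundary condition holds by construction
print(f_theta(0.3, 10.0))  # away from EPS the backbone contributes
```

Because the boundary condition is built into the architecture, it holds for every parameter value, so training never has to learn it.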
Implementation: Sampling
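- One-Step Generation: Sample noise $\hat{x}_T \sim \mathcal{N}(0, T^2 I)$ and compute the data sample directly: $x = f_\theta(\hat{x}_T, T)$. This is very fast, requiring only one network evaluation.
- Multi-Step Sampling (Algorithm 1): Improves sample quality at the cost of extra compute by alternating denoising steps with the CM and noise injection:
- 1. Generate an initial sample $x \leftarrow f_\theta(\hat{x}_T, T)$.
- 2. For each $n = 1, \dots, N-1$:
- Select a time $\tau_n$ from a predefined decreasing sequence $\tau_1 > \tau_2 > \dots > \tau_{N-1}$.
- Add noise: Sample $z \sim \mathcal{N}(0, I)$ and compute $\hat{x}_{\tau_n} = x + \sqrt{\tau_n^2 - \epsilon^2}\, z$.
- Denoise: Compute $x \leftarrow f_\theta(\hat{x}_{\tau_n}, \tau_n)$.
- 3. Output $x$.
- The sequence $\{\tau_n\}$ can be found with optimization methods such as a greedy algorithm using ternary search to minimize FID.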
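One-step and multi-step sampling can be sketched in a few lines. As an assumption for illustration only, the exact consistency function of 1-D Gaussian toy data $\mathcal{N}(M, S^2)$ stands in for a trained $f_\theta$, and the time points `taus` are hand-picked rather than found by the paper's greedy search. Because $\hat{x}_T$ is drawn from $\mathcal{N}(0, T^2 I)$ instead of the exact marginal $p_T$, the one-step output is only approximately distributed like the data; in this toy, each noise-and-denoise step shrinks that mismatch.

```python
import math
import random

# Hypothetical stand-in for a trained f_theta: the exact consistency function
# of 1-D Gaussian toy data N(M, S^2).
M, S = 1.0, 0.5
EPS, T = 0.002, 80.0


def f_theta(x, t):
    return M + (x - M) * math.sqrt(S**2 + EPS**2) / math.sqrt(S**2 + t**2)


def one_step_sample(rng):
    x_hat_T = rng.gauss(0.0, T)  # x_hat_T ~ N(0, T^2)
    return f_theta(x_hat_T, T)   # a single model evaluation


def multistep_sample(taus, rng):
    """Algorithm 1 (sketch): denoise, re-noise to tau_n, denoise again."""
    assert all(a > b for a, b in zip(taus, taus[1:])), "taus must be decreasing"
    x = f_theta(rng.gauss(0.0, T), T)   # step 1: initial one-step sample
    for tau in taus:                    # step 2: add noise, then denoise
        z = rng.gauss(0.0, 1.0)
        x_hat = x + math.sqrt(tau**2 - EPS**2) * z
        x = f_theta(x_hat, tau)
    return x                            # step 3: output


rng = random.Random(0)
samples = [multistep_sample([5.0, 1.0], rng) for _ in range(20_000)]
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(mean, var)  # close to the data mean M = 1.0 and variance S^2 = 0.25
```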
Training Method 1: Consistency Distillation (CD)
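This method trains a CM $f_\theta$ by distilling knowledge from a pre-trained diffusion (score) model $s_\phi(x, t)$.
- Goal: Enforce $f_\theta(x_{t_{n+1}}, t_{n+1}) \approx f_{\theta^-}(\hat{x}^\phi_{t_n}, t_n)$ for adjacent time points $t_n < t_{n+1}$ on the empirical PF ODE trajectory defined by $s_\phi$, where $\hat{x}^\phi_{t_n}$ is obtained from $x_{t_{n+1}}$ by one step of a numerical ODE solver.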
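- Process (Algorithm 2), with time boundaries $\epsilon = t_1 < t_2 < \dots < t_N = T$, repeated until convergence:
- 1. Sample a data point $x \sim p_{\text{data}}$ and an index $n \sim \mathcal{U}[1, N-1]$.
- 2. Sample $x_{t_{n+1}} \sim \mathcal{N}(x, t_{n+1}^2 I)$.
- 3. Take one ODE-solver step with $s_\phi$: $\hat{x}^\phi_{t_n} = x_{t_{n+1}} + (t_n - t_{n+1})\,\Phi(x_{t_{n+1}}, t_{n+1}; \phi)$, where $\Phi$ is the solver's update function (for Euler, $\Phi(x, t; \phi) = -t\, s_\phi(x, t)$).
- 4. Compute $\mathcal{L} = \lambda(t_n)\, d\big(f_\theta(x_{t_{n+1}}, t_{n+1}),\, f_{\theta^-}(\hat{x}^\phi_{t_n}, t_n)\big)$.
- 5. Update $\theta \leftarrow \theta - \eta\, \nabla_\theta \mathcal{L}$, then update the target network $\theta^-$ by EMA.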
- Implementation Details:
- $\theta^-$ is a target network, updated as an Exponential Moving Average (EMA) of $\theta$ (Eq. 8): $\theta^- \leftarrow \mathrm{stopgrad}(\mu\theta^- + (1 - \mu)\theta)$ with decay rate $0 \le \mu < 1$. Applying stop_gradient so that no gradients flow through the target network is crucial for stability.
- $d(\cdot, \cdot)$ is a distance metric. LPIPS (Zhang et al., 2018) works best for images, outperforming L1 and L2.
- $\lambda(t_n)$ is a positive weighting function (often set to 1).
- Higher-order ODE solvers (like Heun) generally perform better than lower-order ones (like Euler) for computing $\hat{x}^\phi_{t_n}$.
- The discretization size $N$ (the horizon $[\epsilon, T]$ is split into $N - 1$ sub-intervals) needs tuning (e.g., $N = 18$ for CIFAR-10 with the Heun solver).
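The CD steps can be sketched in scalar Python. All of the following are assumptions for illustration: 1-D Gaussian toy data whose exact score stands in for the pre-trained $s_\phi$; squared L2 distance in place of LPIPS; $\lambda \equiv 1$; and the Karras et al. (2022) time spacing with $\rho = 7$, which we believe matches the paper's discretization but which is not stated above. Gradients are not computed; in a real framework the target branch would sit under a stop-gradient.

```python
import math
import random

# Hypothetical toy setting: 1-D Gaussian data N(M, S^2); `score` plays the
# role of the pre-trained score model s_phi (here it is exact).
M, S = 1.0, 0.5
EPS, T, RHO = 0.002, 80.0, 7.0


def score(x, t):
    return -(x - M) / (S**2 + t**2)


def karras_times(N):
    """Boundaries EPS = t_1 < ... < t_N = T with Karras et al. spacing (assumed)."""
    a, b = EPS ** (1 / RHO), T ** (1 / RHO)
    return [(a + i / (N - 1) * (b - a)) ** RHO for i in range(N)]


def euler_step(x, t_from, t_to):
    # Euler update with Phi(x, t; phi) = -t * s_phi(x, t).
    return x + (t_to - t_from) * (-t_from * score(x, t_from))


def heun_step(x, t_from, t_to):
    # Heun: average the ODE slope at the start and at the Euler prediction.
    d1 = -t_from * score(x, t_from)
    x_pred = x + (t_to - t_from) * d1
    d2 = -t_to * score(x_pred, t_to)
    return x + (t_to - t_from) * 0.5 * (d1 + d2)


def cd_loss(f_theta, f_target, x, times, rng, solver=heun_step):
    """Single-sample CD loss; squared L2 replaces LPIPS and lambda(t_n) = 1."""
    n = rng.randrange(len(times) - 1)        # picks the adjacent pair (t_n, t_{n+1})
    t_n, t_n1 = times[n], times[n + 1]
    x_t_n1 = x + t_n1 * rng.gauss(0.0, 1.0)  # x_{t_{n+1}} ~ N(x, t_{n+1}^2)
    x_hat = solver(x_t_n1, t_n1, t_n)        # one solver step backward in time
    target = f_target(x_hat, t_n)            # stop-gradient branch in practice
    return (f_theta(x_t_n1, t_n1) - target) ** 2


def ema_update(target_params, online_params, mu):
    """theta^- <- mu * theta^- + (1 - mu) * theta."""
    return [mu * tp + (1 - mu) * p for tp, p in zip(target_params, online_params)]


def f_exact(x, t):
    # Exact consistency function of the toy data, standing in for a perfect model.
    return M + (x - M) * math.sqrt(S**2 + EPS**2) / math.sqrt(S**2 + t**2)


rng = random.Random(0)
times = karras_times(18)
loss = sum(cd_loss(f_exact, f_exact, rng.gauss(M, S), times, rng) for _ in range(2000)) / 2000
print(loss)  # small: only the solver's discretization error remains
```

Even a perfect model incurs a nonzero loss here, because the one-step solver only approximates the true trajectory; this is one reason the solver order and $N$ matter.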
Training Method 2: Consistency Training (CT)
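This method trains a CM $f_\theta$ from scratch, without requiring a pre-trained diffusion model. It makes CMs an independent class of generative models.
- Goal: Enforce $f_\theta(x + t_{n+1} z, t_{n+1}) \approx f_{\theta^-}(x + t_n z, t_n)$, where $x \sim p_{\text{data}}$ and $z \sim \mathcal{N}(0, I)$.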
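- Process (Algorithm 3): Based on Theorem 2, which shows that with the Euler solver the CD loss equals the CT loss (Eq. 9) up to an $o(\Delta t)$ term, where $\Delta t$ is the maximum step size. CT replaces the pre-trained score with the one-sample estimate $-(x_t - x)/t^2$, which is unbiased because $\nabla \log p_t(x_t) = -\mathbb{E}\big[(x_t - x)/t^2 \mid x_t\big]$; each iteration is:
- 1. Sample $x \sim p_{\text{data}}$, $n \sim \mathcal{U}[1, N(k) - 1]$, and $z \sim \mathcal{N}(0, I)$.
- 2. Compute $\mathcal{L} = \lambda(t_n)\, d\big(f_\theta(x + t_{n+1} z, t_{n+1}),\, f_{\theta^-}(x + t_n z, t_n)\big)$.
- 3. Update $\theta \leftarrow \theta - \eta\, \nabla_\theta \mathcal{L}$ and $\theta^- \leftarrow \mathrm{stopgrad}\big(\mu(k)\theta^- + (1 - \mu(k))\theta\big)$, then increment the training step $k$.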
- Implementation Details:
- Uses the same EMA target network $\theta^-$ as CD.
- Crucially, it uses adaptive schedules for the number of discretization points $N(k)$ and the EMA decay rate $\mu(k)$, where $k$ is the training step. $N(k)$ typically starts small and increases, while $\mu(k)$ starts high (e.g., 0.9) and approaches 1. This balances convergence speed and final quality: a small $N$ gives a loss with less variance but more bias, which speeds up early training, while a large $N$ reduces the bias. Appendix C provides specific schedule formulas.
- LPIPS is also effective here.
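A sketch of the CT objective under the same illustrative assumptions (scalar data, squared L2 in place of LPIPS, $\lambda \equiv 1$, a hand-picked time grid). The helper `euler_step_with_estimator` is hypothetical and included only to show the mechanism behind Theorem 2: an Euler step from $x + t_{n+1} z$ that uses the one-sample score estimate $-(x_t - x)/t^2$ lands exactly on $x + t_n z$, which is the target input CT uses.

```python
import math
import random


def ct_loss(f_theta, f_target, x, times, rng):
    """Single-sample CT loss. Both inputs share one noise draw z, so no score
    model is needed. Squared L2 stands in for LPIPS; lambda(t_n) = 1."""
    n = rng.randrange(len(times) - 1)
    t_n, t_n1 = times[n], times[n + 1]
    z = rng.gauss(0.0, 1.0)
    online = f_theta(x + t_n1 * z, t_n1)
    target = f_target(x + t_n * z, t_n)  # stop-gradient branch in practice
    return (online - target) ** 2


def euler_step_with_estimator(x_t, t_from, t_to, x0):
    """Euler PF ODE step where the score is replaced by the one-sample
    estimate -(x_t - x0) / t^2."""
    score_est = -(x_t - x0) / t_from**2
    return x_t + (t_to - t_from) * (-t_from * score_est)


def identity(x, t):
    # A deliberately poor "model" that ignores t.
    return x


# Hypothetical time grid and data point, for illustration only.
times = [0.002, 0.5, 2.0, 10.0, 80.0]
x0, z = 0.7, 0.3
x_t = x0 + 2.0 * z
print(euler_step_with_estimator(x_t, 2.0, 0.5, x0), x0 + 0.5 * z)  # same point, up to rounding
print(ct_loss(identity, identity, x0, times, random.Random(3)))
```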
Practical Applications & Results
- Fast Generation: CMs achieve state-of-the-art FID scores for one-step and two-step generation on CIFAR-10 (3.55/2.93 FID) and ImageNet 64x64 (6.20/4.70 FID) when trained via CD, significantly outperforming Progressive Distillation (PD).
- Standalone Performance: When trained via CT, CMs outperform other one-step non-adversarial methods (VAEs, Flows) and achieve results comparable to PD without needing distillation.
- Zero-Shot Data Editing: CMs inherit the editing capabilities of diffusion models. Using variations of the multi-step sampling algorithm (Algorithm 4 in Appendix), they can perform:
- Inpainting: Mask unknown regions and iteratively refine using the CM.
- Colorization: Treat color channels as missing information in a transformed space (e.g., YUV or using an orthogonal basis).
- Super-Resolution: Treat high-frequency details as missing information in a transformed space (e.g., using patch averaging and orthogonal basis).
- Stroke-guided Editing (SDEdit): Use a stroke image as the starting point of multi-step sampling, in place of the pure-noise input $\hat{x}_T$.
- Denoising: Apply $f_\theta(\cdot, \sigma)$ directly to a noisy image $x_\sigma$ with noise level $\sigma$.
- Interpolation: Interpolate between two initial noise vectors $z_1$ and $z_2$ (e.g., using spherical linear interpolation) and then apply $f_\theta(\cdot, T)$ to the result.
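Three of these operations can be sketched on small Python lists. As before, the exact consistency function of i.i.d. Gaussian toy data stands in for a trained $f_\theta$ (an assumption for illustration). `inpaint` is a simplified pixel-space variant of masked multi-step sampling (identity transform, pure-noise start); the paper's Algorithm 4 is more general and its initialization may differ. Colorization and super-resolution would apply the same masking in a transformed space.

```python
import math
import random

# Hypothetical stand-in for a trained model: the exact consistency function
# of data with i.i.d. N(M, S^2) coordinates, applied elementwise.
M, S = 1.0, 0.5
EPS, T = 0.002, 80.0


def f_theta(x, t):
    scale = math.sqrt(S**2 + EPS**2) / math.sqrt(S**2 + t**2)
    return [M + (xi - M) * scale for xi in x]


def denoise(x_sigma, sigma):
    # Zero-shot denoising: a single evaluation of f_theta at noise level sigma.
    return f_theta(x_sigma, sigma)


def slerp(z1, z2, alpha):
    """Spherical linear interpolation between two noise vectors."""
    dot = sum(a * b for a, b in zip(z1, z2))
    norm = math.sqrt(sum(a * a for a in z1)) * math.sqrt(sum(b * b for b in z2))
    theta = math.acos(max(-1.0, min(1.0, dot / norm)))
    if theta < 1e-8:  # nearly parallel: fall back to linear interpolation
        return [(1 - alpha) * a + alpha * b for a, b in zip(z1, z2)]
    w1 = math.sin((1 - alpha) * theta) / math.sin(theta)
    w2 = math.sin(alpha * theta) / math.sin(theta)
    return [w1 * a + w2 * b for a, b in zip(z1, z2)]


def inpaint(y, known, taus, rng):
    """Multi-step sampling that re-imposes the known pixels of y
    (where known[i] is True) after every denoising step."""
    def keep_known(x):
        return [yi if k else xi for xi, yi, k in zip(x, y, known)]
    x = keep_known(f_theta([rng.gauss(0.0, T) for _ in y], T))
    for tau in taus:
        x_hat = [xi + math.sqrt(tau**2 - EPS**2) * rng.gauss(0.0, 1.0) for xi in x]
        x = keep_known(f_theta(x_hat, tau))
    return x


rng = random.Random(0)
z1 = [rng.gauss(0.0, T) for _ in range(4)]
z2 = [rng.gauss(0.0, T) for _ in range(4)]
path = [f_theta(slerp(z1, z2, a), T) for a in (0.0, 0.5, 1.0)]  # decoded interpolation
print(inpaint([0.9, 1.2, 0.0, 0.0], [True, True, False, False], [5.0, 1.0], rng))
```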
Implementation Considerations
- Architecture: Can reuse U-Net architectures from diffusion models (e.g., NCSN++, ADM).
- Target Network: Using an EMA target network with stop_gradient is vital for both CD and CT.
- Metric: LPIPS is highly recommended for image data.
- Schedules (CT): Carefully designed adaptive schedules for $N(k)$ and $\mu(k)$ are important for CT performance.
- Computational Cost: Training cost is comparable to training diffusion models. Inference is much faster (1 network evaluation for one-step, N evaluations for N-step).
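The adaptive CT schedules can be sketched as below. The functional forms, $N(k) = \lceil \sqrt{k/K\,((s_1+1)^2 - s_0^2) + s_0^2} - 1 \rceil + 1$ and $\mu(k) = \exp(s_0 \log \mu_0 / N(k))$, and the defaults $s_0 = 2$, $s_1 = 150$, $\mu_0 = 0.9$ are our reading of Appendix C for CIFAR-10 and should be checked against the paper; `K` (total training steps) is a hypothetical value. With these forms, $N(0) = s_0$ and $\mu(0) = \mu_0$, $N$ grows to $s_1 + 1$ at $k = K$, and $\mu(k)$ rises toward 1 as $N(k)$ grows.

```python
import math


def n_schedule(k, K, s0=2, s1=150):
    """N(k) = ceil(sqrt(k/K * ((s1 + 1)^2 - s0^2) + s0^2) - 1) + 1."""
    return math.ceil(math.sqrt(k / K * ((s1 + 1) ** 2 - s0**2) + s0**2) - 1) + 1


def mu_schedule(k, K, s0=2, s1=150, mu0=0.9):
    """mu(k) = exp(s0 * log(mu0) / N(k)): the EMA decay tightens as N grows."""
    return math.exp(s0 * math.log(mu0) / n_schedule(k, K, s0, s1))


K = 800_000  # total training steps (hypothetical value for this example)
for k in (0, K // 4, K // 2, K):
    print(k, n_schedule(k, K), round(mu_schedule(k, K), 5))
```

Tying $\mu$ to $N$ keeps the target network's effective averaging window in step with the discretization: coarse grids early on pair with a faster-moving target.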
Continuous-Time Extensions
The paper also derives continuous-time versions of the CD and CT losses (Appendix B), eliminating the need to choose discrete time steps $t_1, \dots, t_N$. These objectives require computing Jacobian-vector products, which typically calls for forward-mode automatic differentiation; not all frameworks support this equally well. Experimental results show they can work well, especially continuous-time CT, but may require careful initialization or variance reduction techniques.
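The Jacobian-vector product these objectives involve is the derivative of the model along the ODE direction, $\partial_t f + (\partial_x f)\,\frac{\mathrm{d}x_t}{\mathrm{d}t}$. Self-consistency in continuous time means exactly that this quantity is zero, since $f$ must be constant along each trajectory. The sketch below computes it in one forward pass with hand-rolled dual numbers (a hypothetical minimal helper, not a framework API; JAX exposes the same operation as `jax.jvp`), using the exact toy consistency function of Gaussian data as $f$. The precise loss forms are in Appendix B of the paper.

```python
import math


class Dual:
    """Minimal forward-mode AD number val + tan * e with e^2 = 0 (hypothetical helper)."""

    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    @staticmethod
    def lift(other):
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        o = Dual.lift(other)
        return Dual(self.val + o.val, self.tan + o.tan)

    __radd__ = __add__

    def __sub__(self, other):
        o = Dual.lift(other)
        return Dual(self.val - o.val, self.tan - o.tan)

    def __rsub__(self, other):
        return Dual.lift(other) - self

    def __mul__(self, other):
        o = Dual.lift(other)
        return Dual(self.val * o.val, self.tan * o.val + self.val * o.tan)

    __rmul__ = __mul__

    def __truediv__(self, other):
        o = Dual.lift(other)
        return Dual(self.val / o.val, (self.tan * o.val - self.val * o.tan) / o.val**2)


def dsqrt(d):
    r = math.sqrt(d.val)
    return Dual(r, d.tan / (2 * r))


# Hypothetical toy model: exact consistency function of N(M, S^2) data.
M, S, EPS = 1.0, 0.5, 0.002


def f(x, t):
    return M + (x - M) * math.sqrt(S**2 + EPS**2) / dsqrt(S**2 + t * t)


def ode_velocity(x, t):
    # dx/dt = -t * score(x, t) for the toy marginals p_t = N(M, S^2 + t^2).
    return t * (x - M) / (S**2 + t**2)


def total_time_derivative(x, t):
    """JVP of f with tangent (dx/dt, 1): d_t f + (d_x f) * dx/dt in one pass."""
    return f(Dual(x, ode_velocity(x, t)), Dual(t, 1.0)).tan


print(total_time_derivative(3.0, 2.0))        # ~0: f is constant along the ODE
print(f(Dual(3.0, 0.0), Dual(2.0, 1.0)).tan)  # the partial d_t f alone is not zero
```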