DisCo-Diff: Enhancing Continuous Diffusion Models with Discrete Latents

(2407.03300)
Published Jul 3, 2024 in cs.LG, cs.AI, and cs.CV

Abstract

Diffusion models (DMs) have revolutionized generative learning. They utilize a diffusion process to encode data into a simple Gaussian distribution. However, encoding a complex, potentially multimodal data distribution into a single continuous Gaussian distribution arguably represents an unnecessarily challenging learning problem. We propose Discrete-Continuous Latent Variable Diffusion Models (DisCo-Diff) to simplify this task by introducing complementary discrete latent variables. We augment DMs with learnable discrete latents, inferred with an encoder, and train DM and encoder end-to-end. DisCo-Diff does not rely on pre-trained networks, making the framework universally applicable. The discrete latents significantly simplify learning the DM's complex noise-to-data mapping by reducing the curvature of the DM's generative ODE. An additional autoregressive transformer models the distribution of the discrete latents, a simple step because DisCo-Diff requires only a few discrete variables with small codebooks. We validate DisCo-Diff on toy data, several image synthesis tasks, and molecular docking, and find that introducing discrete latents consistently improves model performance. For example, DisCo-Diff achieves state-of-the-art FID scores on class-conditioned ImageNet-64/128 datasets with an ODE sampler.

Figure: Husky image modeled with Discrete-Continuous Latent Variable Diffusion Models using a Vision Transformer encoder.

Overview

  • DisCo-Diff improves continuous diffusion models by integrating discrete latent variables, simplifying complex mappings and enhancing performance.

  • The framework is trained end-to-end, incorporating an autoregressive transformer to efficiently model discrete latents, achieving state-of-the-art results on ImageNet benchmarks and superior performance in molecular docking tasks.

  • DisCo-Diff's architecture includes a U-Net denoiser network, Vision Transformer encoder, and autoregressive transformer decoder, with theoretical and practical implications for broader applications in machine learning and healthcare.

DisCo-Diff: Enhancing Continuous Diffusion Models with Discrete Latents

The paper introduces Discrete-Continuous Latent Variable Diffusion Models (DisCo-Diff), a novel approach that seeks to enhance continuous diffusion models (DMs) by incorporating discrete latent variables. The authors argue that directly encoding complex, multimodal data distributions into a unimodal Gaussian distribution, as continuous diffusion models do, represents an unnecessarily challenging learning problem. DisCo-Diff proposes a more efficient framework by augmenting DMs with discrete latent variables, inferred using an encoder trained end-to-end with the DM.

Key Contributions

  1. Introduction of Discrete Latents: The core innovation of DisCo-Diff is the augmentation of traditional DMs with discrete latent variables. This addition aims to simplify the complex noise-to-data mapping in DMs by reducing the curvature of the generative Ordinary Differential Equation (ODE). The discrete latents are inferred through an encoder and incorporated during the diffusion process, improving the model's performance significantly.
  2. End-to-End Training Without Pre-trained Networks: DisCo-Diff is trained end-to-end, including both the DM and the encoder. This is a significant advantage over previous approaches that rely on pre-trained networks, which can be domain-specific and not universally applicable. End-to-end training also aligns the discrete latents with the DM's score-matching objective, facilitating a more effective learning process (a minimal training sketch follows this list).
  3. Autoregressive Modeling of Discrete Latents: The distribution of discrete latents is modeled using an autoregressive transformer. This method is computationally efficient because DisCo-Diff requires only a few discrete variables with small codebooks.
  4. Improved Performance Metrics: The paper presents strong numerical results, with DisCo-Diff achieving state-of-the-art Fréchet Inception Distance (FID) scores on class-conditioned ImageNet-64 and ImageNet-128 datasets using ODE samplers. For instance, on ImageNet-64, DisCo-Diff achieved an FID score of 1.65, significantly improving over previous methods.
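
The end-to-end objective behind contributions 1 and 2 can be summarized in a few lines of code. The snippet below is a minimal, illustrative PyTorch-style sketch rather than the authors' implementation: it assumes an EDM-style denoising loss, a hypothetical Encoder that maps clean data to a few categorical latents (here discretized with a simple straight-through estimator purely for illustration; the paper's exact discretization mechanism may differ), and a Denoiser that accepts those latents as conditioning.

```python
import torch
import torch.nn.functional as F

def disco_diff_loss(denoiser, encoder, x0, sigma):
    """One joint training step for the DM and the discrete-latent encoder (illustrative sketch).

    encoder(x0)                 -> logits over codebook entries, shape [B, n_latents, codebook_size]
    denoiser(x_noisy, sigma, z) -> prediction of the clean data, conditioned on latents z
    x0: clean data batch [B, C, H, W]; sigma: per-sample noise levels [B] (EDM-style)
    """
    # Infer the discrete latents from the *clean* data.
    logits = encoder(x0)                                   # [B, K, V]
    probs = F.softmax(logits, dim=-1)
    # Straight-through discretization: hard one-hot forward pass, soft gradients backward.
    hard = F.one_hot(probs.argmax(dim=-1), probs.shape[-1]).float()
    z = hard + probs - probs.detach()                      # [B, K, V]

    # Diffuse the data to noise level sigma.
    noise = torch.randn_like(x0)
    x_noisy = x0 + sigma.view(-1, 1, 1, 1) * noise

    # Denoising (score-matching) loss; gradients flow into both denoiser and encoder.
    x_pred = denoiser(x_noisy, sigma, z)
    return F.mse_loss(x_pred, x0)
```

Because the latents are inferred from the clean data, the encoder is needed only during training; at generation time the latents are instead drawn from the autoregressive prior described in contribution 3.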

Experimental Validation

The authors validate DisCo-Diff on several tasks, including image synthesis and molecular docking. In all these tasks, introducing discrete latents consistently improves the model's performance. Notably, DisCo-Diff achieved state-of-the-art results on ImageNet generation benchmarks and demonstrated improved performance in molecular docking, a critical task in drug discovery.

  1. Image Synthesis:

    • Class-Conditioned ImageNet-64 and ImageNet-128: DisCo-Diff outperforms strong baselines and sets new state-of-the-art FID scores. For example, it reduced the previous best FID score of 2.36 on ImageNet-64 to 1.65 using an ODE sampler.
    • Additional Samplers: With a stochastic Restart sampler, DisCo-Diff further improves the class-conditional ImageNet-64 FID to 1.22, setting a new record for that benchmark.

    The results demonstrate that discrete latents provide meaningful improvements in generative tasks by simplifying the complex noise-to-data mapping inherent to traditional DMs; the two-stage sampling procedure behind these results is sketched after this list.

  2. Molecular Docking:

    • Performance Boost: DisCo-Diff was applied to molecular docking, i.e., predicting the 3D pose in which a small-molecule ligand binds to a protein. The experiments show that including discrete latents enhances the DM's performance, achieving a higher rate of accurately predicted docking poses than baseline methods.
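
Sampling from DisCo-Diff is a two-stage procedure: first draw the discrete latents from the autoregressive transformer, then integrate the DM's generative ODE conditioned on them. The sketch below is illustrative only; it assumes hypothetical latent_prior and denoiser interfaces and uses a plain Euler integration of an EDM-style probability-flow ODE rather than the specific samplers reported in the paper.

```python
import torch

@torch.no_grad()
def sample(denoiser, latent_prior, shape, sigmas, device="cuda"):
    """Two-stage DisCo-Diff sampling (illustrative sketch).

    latent_prior.sample(batch) -> discrete latents z (e.g. [B, K] codebook indices)
    denoiser(x, sigma, z)      -> predicted clean data at noise level sigma
    sigmas: decreasing noise schedule, e.g. EDM-style from sigma_max down to ~0
    """
    batch = shape[0]
    # Stage 1: sample the few discrete latents autoregressively.
    z = latent_prior.sample(batch)

    # Stage 2: integrate the probability-flow ODE, conditioned on z (Euler steps).
    x = torch.randn(shape, device=device) * sigmas[0]
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoiser(x, sigma, z)) / sigma      # ODE drift dx/dsigma
        x = x + (sigma_next - sigma) * d             # Euler step toward lower noise
    return x
```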

Architectural Design

The architecture of DisCo-Diff includes several critical components:

  1. Denoiser Network: The denoiser network is a U-Net for images, augmented with cross-attention layers to incorporate discrete latents at multiple levels of resolution.
  2. Encoder: The encoder is a Vision Transformer (ViT), with additional classification tokens for each discrete latent, ensuring global features of the data are effectively captured.
  3. Autoregressive Transformer: A 12-layer transformer decoder models the distribution of the discrete latents, realized in a simple yet efficient manner due to the low-dimensional nature of these latents.

Additionally, the authors explore a hierarchical design where different discrete latents are fed into various levels of the U-Net architecture. This design approach encourages the different latents to encode diverse aspects of the data, such as shape and color, aligning with observations from the generative adversarial network literature.
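
As a rough illustration of this conditioning mechanism, the snippet below shows how a handful of discrete latents might be embedded and injected into U-Net feature maps via cross-attention. It is a simplified sketch under assumed shapes and module names, not the paper's exact architecture.

```python
import torch
import torch.nn as nn

class LatentCrossAttention(nn.Module):
    """Cross-attention from U-Net feature maps to embedded discrete latents (sketch)."""

    def __init__(self, channels, codebook_size, dim=256, heads=4):
        super().__init__()
        self.embed = nn.Embedding(codebook_size, dim)            # embedding per codebook entry
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.to_tokens = nn.Linear(channels, dim)
        self.to_channels = nn.Linear(dim, channels)

    def forward(self, feats, z_indices):
        # feats: [B, C, H, W] U-Net features; z_indices: [B, K] discrete latent codes.
        B, C, H, W = feats.shape
        q = self.to_tokens(feats.flatten(2).transpose(1, 2))     # [B, H*W, dim] queries
        kv = self.embed(z_indices)                                # [B, K, dim] keys/values
        out, _ = self.attn(q, kv, kv)
        out = self.to_channels(out).transpose(1, 2).view(B, C, H, W)
        return feats + out                                        # residual injection

# In a hierarchical design, different subsets of the K latents could be routed to
# cross-attention layers at different U-Net resolutions, so that individual latents
# specialize (e.g. coarse shape at low resolution, color/texture at higher resolution).
```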

Theoretical and Practical Implications

Theoretical Implications: The proposed model simplifies the complex noise-to-data mapping of traditional DMs by introducing discrete latents that capture global patterns in the data. Conditioning on these latents lowers the curvature of the generative ODE, which makes the conditional score easier to learn with the standard score-matching objective.
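
Concretely, in EDM-style notation the generative (probability-flow) ODE is conditioned on the discrete latents z; the expression below is an illustrative paraphrase in that convention rather than a verbatim equation from the paper.

```latex
% Conditional probability-flow ODE, EDM-style notation (illustrative paraphrase)
\frac{\mathrm{d}\mathbf{x}}{\mathrm{d}\sigma}
  = -\sigma \, \nabla_{\mathbf{x}} \log p(\mathbf{x} \mid \mathbf{z}; \sigma)
  = \frac{\mathbf{x} - D_\theta(\mathbf{x}, \sigma, \mathbf{z})}{\sigma}
```

With the additional conditioning on z, the intermediate distribution p(x | z; sigma) is closer to unimodal than the marginal p(x; sigma), so the ODE trajectories bend less between the Gaussian prior and the data, which is the curvature reduction the paper reports.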

Practical Implications: DisCo-Diff's universal framework has practical implications across multiple domains. By improving the fidelity and diversity of generated samples in image synthesis tasks and enhancing predictive capabilities in molecular docking, DisCo-Diff demonstrates its versatility and potential for broader applications in machine learning and healthcare.

Future Work

The authors suggest several future directions for DisCo-Diff:

  • Application to Other Generative Models: Extending the framework to other continuous flow models such as flow-matching or rectified flow.
  • Exploration in Other Domains: Applying DisCo-Diff to more varied data modalities, such as text-to-image generation and other scientific domains.
  • Single-Stage Training: Investigating training the autoregressive latent prior jointly with the DM and encoder, removing the current two-stage procedure in which the prior is fitted only after the DM and encoder have been trained.

Conclusion

DisCo-Diff offers a promising advancement in the field of generative modeling by effectively combining discrete and continuous latents. The introduction of discrete latents not only simplifies the data generation process but also yields significant performance improvements across challenging tasks and datasets. The framework's universality, strong empirical results, and theoretical contributions mark a noteworthy step forward in the evolution of diffusion models.
