Emergent Mind

Simplified and Generalized Masked Diffusion for Discrete Data

(2406.04329)
Published Jun 6, 2024 in cs.LG and stat.ML

Abstract

Masked (or absorbing) diffusion is actively explored as an alternative to autoregressive models for generative modeling of discrete data. However, existing work in this area has been hindered by unnecessarily complex model formulations and unclear relationships between different perspectives, leading to suboptimal parameterization, training objectives, and ad hoc adjustments to counteract these issues. In this work, we aim to provide a simple and general framework that unlocks the full potential of masked diffusion models. We show that the continuous-time variational objective of masked diffusion models is a simple weighted integral of cross-entropy losses. Our framework also enables training generalized masked diffusion models with state-dependent masking schedules. When evaluated by perplexity, our models trained on OpenWebText surpass prior diffusion language models at GPT-2 scale and demonstrate superior performance on 4 out of 5 zero-shot language modeling tasks. Furthermore, our models vastly outperform previous discrete diffusion models on pixel-level image modeling, achieving 2.78 (CIFAR-10) and 3.42 (ImageNet 64×64) bits per dimension, comparable to or better than autoregressive models of similar size.

Figure: Iterative token unmasking with color-coded steps in a 1024-token sequence generated by MD4.

Overview

  • The paper introduces a unified framework for masked diffusion models, simplifying the existing formulations and enhancing model performance for generative modeling of discrete data.

  • A key innovation is state-dependent masking schedules, which allow the model to prioritize which tokens are masked and unmasked based on their values, improving performance.

  • Empirical evaluations show that the proposed models outperform previous diffusion models on text and image datasets, achieving superior results and demonstrating high-quality synthesis.

The paper "Simplified and Generalized Masked Diffusion for Discrete Data" by Jiaxin Shi, Kehang Han, Zhe Wang, Arnaud Doucet, and Michalis K. Titsias introduces a unified framework for masked (or absorbing) diffusion models aimed at generative modeling of discrete data. The authors address several limitations in existing masked diffusion models, such as complex formulations and suboptimal training objectives, and present a more straightforward approach that improves performance significantly.

Key Contributions

  1. Simplified Theoretical Framework: The authors show that the continuous-time variational objective for masked diffusion models can be represented as a simple weighted integral of cross-entropy losses. This formulation unifies various approaches from the literature and clarifies the relationships between them.

  2. Generalized Model with State-Dependent Masking Schedules: The paper extends the standard formulation by incorporating state-dependent masking schedules, allowing the model to prioritize the masking and unmasking of specific tokens based on their states. This generalization enhances the model's flexibility and performance.

  3. Improved Parameterization and Training Objectives: The authors parameterize the model to predict the mean of the clean data given the masked data (mean parameterization), arguing that this yields more stable and effective training than the score-based parameterizations used in prior work. The resulting models outperform previous diffusion models on standard benchmarks.

  4. Empirical Results: The models trained using the proposed framework achieve superior likelihood and zero-shot transfer performance on text modeling tasks. Specifically, the models exhibit better perplexity on OpenWebText and strong performance on several zero-shot language modeling tasks compared to existing diffusion models and GPT-2.

  5. Application to Image Data: The paper demonstrates the efficacy of the proposed framework in pixel-level image modeling tasks on datasets like CIFAR-10 and Downsampled ImageNet 64x64. The new models significantly outperform existing discrete diffusion models and match or exceed the performance of autoregressive models of similar size.
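As a sketch of what contribution 1 looks like (notation assumed, not quoted verbatim from the paper: $\alpha_t \in [0,1]$ is the probability that a token remains unmasked at time $t$, decreasing from $\alpha_0 = 1$ to $\alpha_1 = 0$; $m$ is the mask token; $\mu_\theta$ is the network predicting the clean-token distribution at each position $n$), the continuous-time objective takes the form of a weighted integral of cross-entropy losses over masked positions:

```latex
\mathcal{L}_\infty(x) \;=\; \int_0^1 \frac{\alpha_t'}{1-\alpha_t}\,
  \mathbb{E}_{q_{t\mid 0}(z_t \mid x)}\!\Big[\sum_{n\,:\,z_t^n = m}
  \log \mu_\theta(z_t)^n_{x^n}\Big]\, dt
```

Since $\alpha_t' < 0$ and each log-probability is non-positive, the integrand is non-negative, so $\mathcal{L}_\infty$ is a valid loss (a negative ELBO).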

Experimental Evaluation

Text Modeling

For text modeling, the authors train their models on OpenWebText and evaluate them on tasks such as LAMBADA, WikiText2, and Penn Treebank. The results show that their masked diffusion models (referred to as MD4 and GenMD4) outperform previous methods like D3PM and SEDD Absorb in terms of zero-shot perplexity. The models also demonstrate faster convergence and better final likelihoods on the validation set.

On the text8 dataset, the MD4 and GenMD4 models achieve lower bits-per-character (BPC) than previous state-of-the-art diffusion models and any-order autoregressive models. The GenMD4 model further improves BPC, showcasing the benefits of state-dependent masking schedules.
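To make the "weighted integral of cross-entropy losses" concrete, here is a minimal NumPy sketch of one Monte Carlo estimate of the training loss under a linear masking schedule $\alpha_t = 1 - t$. Everything here (the `MASK_ID` convention, the `predict_logits` callable, the clipping of $t$) is an illustrative assumption, not the authors' implementation:

```python
import numpy as np

MASK_ID = 256  # hypothetical: one extra id beyond a 256-token pixel vocabulary

def masked_diffusion_loss(predict_logits, x0, rng, vocab_size=256):
    """One Monte Carlo estimate of the continuous-time loss: for each
    sequence, draw a time t, mask each token with probability t (linear
    schedule alpha_t = 1 - t), and weight the cross-entropy on masked
    positions by -alpha'_t / (1 - alpha_t) = 1 / t.  Sketch only."""
    B, L = x0.shape
    t = np.clip(rng.random((B, 1)), 1e-3, 1.0)       # time in (0, 1], clipped for stability
    mask = rng.random((B, L)) < t                    # masked with prob 1 - alpha_t = t
    zt = np.where(mask, MASK_ID, x0)
    logits = predict_logits(zt)                      # (B, L, vocab_size) clean-token logits
    z = logits - logits.max(axis=-1, keepdims=True)  # stable log-softmax
    logp = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    ce = -np.take_along_axis(logp, x0[..., None], axis=-1)[..., 0]  # (B, L)
    return ((1.0 / t) * ce * mask).sum(axis=1).mean()
```

With this weighting, positions masked late (small $t$, few tokens masked) receive large weights, which is exactly the $1/t$ behavior of the linear schedule's ELBO weight.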

Image Modeling

In pixel-level image modeling, MD4 sets a new state-of-the-art for discrete diffusion models on CIFAR-10 and matches the performance of autoregressive models on ImageNet 64x64. The paper includes several samples generated by MD4, demonstrating high-quality image synthesis despite modeling pixels as discrete tokens.
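The iterative token unmasking used at sampling time can be sketched as an ancestral sampler: start from an all-mask sequence and reveal tokens over reverse-time steps. The reveal probability $(t - s)/t$ below follows from a linear schedule $\alpha_t = 1 - t$; the step count, `MASK_ID`, and `predict_logits` interface are illustrative assumptions, not the authors' exact sampler:

```python
import numpy as np

MASK_ID = 256  # hypothetical id for the mask token, outside a 256-token vocabulary

def sample(predict_logits, length, steps, rng, vocab_size=256):
    """Ancestral-sampling sketch: begin fully masked and, at each reverse
    step t -> s, reveal each still-masked token with probability
    (alpha_s - alpha_t) / (1 - alpha_t) = (t - s) / t  for alpha_t = 1 - t,
    drawing its value from the model's predicted clean-token distribution."""
    x = np.full(length, MASK_ID, dtype=np.int64)
    for i in range(steps, 0, -1):
        s, t = (i - 1) / steps, i / steps
        logits = predict_logits(x[None])[0]          # (length, vocab_size)
        p_reveal = (t - s) / t                       # equals 1 at the final step
        for n in range(length):
            if x[n] == MASK_ID and rng.random() < p_reveal:
                p = np.exp(logits[n] - logits[n].max())   # softmax over vocabulary
                x[n] = rng.choice(vocab_size, p=p / p.sum())
    return x
```

Because the reveal probability reaches 1 at the final step, the sampler always returns a fully unmasked sequence regardless of the number of steps.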

Theoretical Insights

The authors provide several theoretical results that enhance understanding and training of masked diffusion models. They derive the continuous-time limit of the Evidence Lower Bound (ELBO) for masked diffusion models and show that it is invariant to the choice of noise schedule. They also establish connections to existing work on continuous-time Markov chains and alternative parameterization approaches.
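One way to see the schedule-invariance claim (a sketch; notation assumed: $\alpha_t$ is the probability a token remains unmasked at time $t$, decreasing from $\alpha_0 = 1$ to $\alpha_1 = 0$, and $\mathrm{CE}(\gamma)$ denotes the expected cross-entropy on masked positions at masking level $\gamma = 1 - \alpha_t$): substituting $\gamma = 1 - \alpha_t$, $d\gamma = -\alpha_t'\,dt$ into the weighted cross-entropy integral gives

```latex
\int_0^1 \frac{-\alpha_t'}{1-\alpha_t}\,\mathrm{CE}(1-\alpha_t)\,dt
  \;=\; \int_0^1 \frac{\mathrm{CE}(\gamma)}{\gamma}\,d\gamma ,
```

which depends only on the endpoints $\alpha_0 = 1$ and $\alpha_1 = 0$, not on how the schedule interpolates between them.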

Future Directions

The paper concludes by suggesting future research directions, including the development of better architectures for discrete diffusion models and more robust state-dependent masking schedules. Additionally, the authors highlight the potential for extending their framework to other domains beyond text and image data.

Conclusion

The proposed framework for simplified and generalized masked diffusion models represents a significant advancement in generative modeling of discrete data. By addressing complexities in existing models and introducing state-dependent masking schedules, the authors achieve substantial improvements in both theoretical formulation and empirical performance. This work provides a solid foundation for future research in discrete diffusion models and their applications.
