Dataset Distillation by Matching Training Trajectories (2203.11932v1)

Published 22 Mar 2022 in cs.CV, cs.AI, and cs.LG

Abstract: Dataset distillation is the task of synthesizing a small dataset such that a model trained on the synthetic set will match the test accuracy of the model trained on the full dataset. In this paper, we propose a new formulation that optimizes our distilled data to guide networks to a similar state as those trained on real data across many training steps. Given a network, we train it for several iterations on our distilled data and optimize the distilled data with respect to the distance between the synthetically trained parameters and the parameters trained on real data. To efficiently obtain the initial and target network parameters for large-scale datasets, we pre-compute and store training trajectories of expert networks trained on the real dataset. Our method handily outperforms existing methods and also allows us to distill higher-resolution visual data.

Citations (308)

Summary

  • The paper's main contribution is a novel technique that matches synthetic data-induced parameter trajectories with expert trajectories to replicate full-data performance.
  • The methodology leverages backpropagation through multiple training updates to align long-range dynamics, markedly improving accuracy on CIFAR-10, CIFAR-100, and Tiny ImageNet.
  • Empirical results demonstrate significant accuracy gains, establishing state-of-the-art performance and enabling the distillation of high-resolution synthetic datasets.

Dataset Distillation by Matching Training Trajectories

Abstract

The paper "Dataset Distillation by Matching Training Trajectories" (2203.11932) introduces a novel approach to dataset distillation, which is the task of creating a small synthetic dataset that allows a model trained on this dataset to achieve comparable test accuracy to a model trained on the full dataset. This method leverages long-range training dynamics by matching synthetic data induced parameter trajectories with expert trajectories derived from networks trained on real data. The approach significantly improves upon previous methods by distilling higher-resolution visual data, with substantial empirical results demonstrating superior performance across various datasets including CIFAR-10, CIFAR-100, and Tiny ImageNet.

Introduction

Dataset distillation aims to compress a large dataset into a smaller set of high-information synthetic images while preserving task-specific features necessary for model generalization. Unlike model distillation, which focuses on compressing model complexity, dataset distillation compresses the training data itself. This paper proposes optimizing synthetic data to emulate long-range training characteristics of real data by matching training parameter trajectories (Figure 1).

Figure 1: Dataset distillation aims to generate a small synthetic dataset for which a model trained on it can achieve a similar test performance as a model trained on the whole real train set.

The method employs pre-computed expert trajectories recorded from networks trained on the full dataset, using these as a gold standard to guide the distillation process. This long-range trajectory matching addresses challenges such as optimization difficulty and accumulation of error when matching only short-range behavior.
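
As a rough illustration of how such expert trajectories might be pre-computed and stored, the sketch below trains a network on the real dataset and snapshots its flattened parameters after every epoch. The function name, snapshot granularity, and optimizer settings are illustrative assumptions rather than the paper's exact recipe, which trains many differently initialized experts and saves their checkpoints to disk for reuse.

```python
import torch
import torch.nn.functional as F

def record_expert_trajectory(model, real_loader, epochs, lr=0.01, momentum=0.9):
    """Train on real data and snapshot the flattened parameters after each epoch."""
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    snapshots = [torch.nn.utils.parameters_to_vector(model.parameters()).detach().clone()]
    for _ in range(epochs):
        for x, y in real_loader:
            opt.zero_grad()
            F.cross_entropy(model(x), y).backward()
            opt.step()
        # Each saved vector is one point theta*_t along the expert trajectory.
        snapshots.append(torch.nn.utils.parameters_to_vector(model.parameters()).detach().clone())
    return snapshots  # e.g. torch.save(snapshots, "expert_0.pt") to reuse during distillation
```

Storing trajectories offline in this way means the expensive real-data training is paid once; the distillation stage then only needs to load pairs of checkpoints (an initial point and a target point several epochs later) from disk.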

Methodology

The distillation method is centered around matching the synthetic dataset's influence on model training dynamics with that of pre-recorded expert trajectories derived from networks trained on real data (Figure 2). This involves initializing a model from expert parameters at a random time step, training it on the synthetic dataset, and penalizing deviations from the expert trajectory.

Figure 2: We perform long-range parameter matching between training on distilled synthetic data and training on real data. Starting from the same initial parameters, we train the distilled data $\mathcal{D}_\mathsf{syn}$ such that $N$ training steps on it match the result (in parameter space) of many more $M$ steps on real data.

The key aspect of the methodology lies in backpropagating through multiple training updates on the synthetic dataset, aligning the parameters reached after training on synthetic data with those the expert trajectory reaches after many more steps on real data. This trajectory-based approach contrasts with short-range matching methods and yields improvements through its emphasis on long-range learning behavior.
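
Concretely, the objective described above can be written as a normalized weight-matching loss: starting from an expert checkpoint $\theta^*_t$, the student parameters $\hat{\theta}_{t+N}$ obtained after $N$ unrolled updates on the synthetic data are compared with the expert checkpoint $\theta^*_{t+M}$, normalized by how far the expert itself moved between steps $t$ and $t+M$. The sketch below illustrates this idea for a toy linear classifier on flattened images; the model, tensor shapes, and function names are simplifying assumptions (the paper uses ConvNet students, a trainable inner step size $\alpha$, and differentiable augmentation), and the random checkpoints stand in for parameters loaded from a stored expert trajectory.

```python
import torch
import torch.nn.functional as F

def linear_logits(theta, x, num_classes=10, dim=3 * 32 * 32):
    # Interpret a flat parameter vector as (W, b) of a linear classifier.
    W = theta[: num_classes * dim].view(num_classes, dim)
    b = theta[num_classes * dim:]
    return x.view(x.size(0), -1) @ W.t() + b

def trajectory_matching_loss(x_syn, y_syn, theta_start, theta_target, alpha, n_steps):
    # Unroll n_steps of SGD on the synthetic set, keeping the graph so the
    # outer loss can backpropagate into x_syn (and the inner step size alpha).
    theta = theta_start.detach().clone().requires_grad_(True)
    for _ in range(n_steps):
        inner_loss = F.cross_entropy(linear_logits(theta, x_syn), y_syn)
        grad, = torch.autograd.grad(inner_loss, theta, create_graph=True)
        theta = theta - alpha * grad
    # Normalized weight-matching loss: distance between the student endpoint and
    # the expert checkpoint M steps ahead, scaled by how far the expert moved.
    return ((theta - theta_target) ** 2).sum() / ((theta_start - theta_target) ** 2).sum()

# Illustrative outer loop: optimize the synthetic images (and alpha) with SGD + momentum.
num_classes, dim = 10, 3 * 32 * 32
x_syn = torch.randn(num_classes, 3, 32, 32, requires_grad=True)  # one image per class
y_syn = torch.arange(num_classes)
alpha = torch.tensor(0.02, requires_grad=True)                   # trainable inner step size
theta_start = torch.randn(num_classes * dim + num_classes)       # stand-in for theta*_t
theta_target = torch.randn(num_classes * dim + num_classes)      # stand-in for theta*_{t+M}
outer_opt = torch.optim.SGD([x_syn, alpha], lr=0.1, momentum=0.5)

outer_opt.zero_grad()
loss = trajectory_matching_loss(x_syn, y_syn, theta_start, theta_target, alpha, n_steps=5)
loss.backward()
outer_opt.step()
```

Because the unrolled updates stay in the autograd graph, a single backward pass propagates the weight-matching loss through all $N$ inner steps into the synthetic images and the inner step size, which is the long-range signal the paper argues short-horizon matching lacks.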

Experimental Results

The experiments conducted on CIFAR-10, CIFAR-100, and Tiny ImageNet demonstrate compelling results, with substantial accuracy improvements over existing methods:

  • CIFAR-10: Achieved 46.3% accuracy with a single image per class, compared to the previous state-of-the-art of 28.8% [dsa], and 65.3% accuracy with 10 images per class.
  • CIFAR-100: Enhanced performance to 24.3% accuracy with one image per class, increasing from prior results of 13.9% [dsa].
  • Tiny ImageNet: The method improved accuracy to 8.8% with one image per class, outperforming concurrent work DM [dm].

Furthermore, the approach scaled to distilling higher-resolution $128 \times 128$ ImageNet images for the first time, exploring various subsets such as ImageNette and ImageWoof with substantial classification performance gains.

Implications and Future Work

The method's ability to accurately distill datasets and generalize over a long trajectory provides deeper insights into model training dynamics and dataset composition. While promising, challenges remain regarding memory consumption and computational costs associated with expert trajectory training. Future work could explore adaptive selection of expert trajectory segments, further optimizing distillation processes for different model architectures and datasets, potentially facilitating practical applications in privacy-preserving ML and efficient neural architecture search.

Conclusion

This paper presents a dataset distillation framework leveraging long-range trajectory matching, markedly enhancing distillation efficacy and our understanding of dataset composition. Through comprehensive experiments, the authors demonstrate the feasibility of synthesizing compact, high-resolution representative datasets, opening avenues for practical real-world use and advancing theoretical discourse in dataset distillation.

Glossary

  • Back-propagation: A method for computing gradients through a sequence of operations to update parameters during training. "and back-propagate through the training iterations."
  • Behavior cloning: An imitation learning approach that trains a policy to mimic expert actions directly. "Behavior cloning trains the learning policy to act the same as expert demonstrations."
  • Bi-level optimization: An optimization framework with nested objectives (an inner and outer problem), often used to select or tune model components. "new formulations based on bi-level optimization have shown promising results on applications like continual learning"
  • Catastrophic mode collapse: A failure mode in generative or synthesis processes where outputs lose diversity and collapse into a few modes. "degrading to catastrophic mode collapse in the worst case."
  • Continual learning: Training that aims to learn a sequence of tasks without forgetting earlier ones. "including continual learning"
  • Coreset: A small, representative subset of data designed to preserve performance when training on it instead of the full set. "coreset and instance selection aim to select a subset of the entire training dataset"
  • Coreset selection: The process of choosing a coreset from the full dataset to enable efficient training. "Comparing distillation and coreset selection methods."
  • Cosine distance: A similarity measure based on the cosine of the angle between two vectors, used as a loss or metric. "such as a cosine distance"
  • Dataset distillation: Synthesizing a small dataset so that models trained on it match the performance of training on the full dataset. "Dataset distillation is the task of synthesizing a small dataset such that a model trained on the synthetic set will match the test accuracy of the model trained on the full dataset."
  • Differentiable augmentation: Data augmentation operations that are differentiable, allowing gradients to pass through them to optimize inputs. "where $\mathcal{A}$ is the differentiable augmentation technique"
  • Differentiable Siamese Augmentation: A differentiable augmentation scheme that applies matched transformations to pairs of inputs; here not used during distillation. "Our method does not use differentiable Siamese augmentation since there is no real data used during the distillation process;"
  • Distribution matching: Aligning the synthetic data distribution to the real data distribution without modeling optimization steps. "instead focusing on distribution matching between synthetic and real data."
  • Expert trajectories: Recorded sequences of model parameters from training on real data, used as targets for imitation. "using expert trajectories $\tau^*$ to guide the distillation of our synthetic dataset."
  • Federated learning: Training models across decentralized data sources without centralizing raw data. "federated learning"
  • Generative model: A model that learns to produce realistic data samples from a learned distribution. "A related line of research learns a generative model to synthesize training data"
  • Gradient matching: A technique that aligns gradients from synthetic data with those from real data to amplify learning signals. "amplifying learning signal via gradient matching"
  • Hyperparameter optimization: Gradient-based or search-based tuning of training hyperparameters to improve performance. "optimized them using gradient-based hyperparameter optimization"
  • Imitation learning: Learning a policy by mimicking expert demonstrations or behaviors. "Imitation learning attempts to learn a good policy by observing a collection of expert demonstrations"
  • Infinite-width kernel limit: The regime where neural networks behave like kernel machines as width goes to infinity, enabling analytical training via kernels. "optimizing with respect to the infinite-width kernel limit"
  • Instance normalization: A normalization technique that normalizes feature maps per-instance, often used in vision models. "Instance normalization"
  • Instance selection: Choosing specific training examples to form a smaller, effective training set. "coreset and instance selection aim to select a subset of the entire training dataset"
  • Kernel Inducing Point (KIP): A method that distills data using kernel techniques derived from the infinite-width limit. "Kernel Inducing Point (KIP) performs distillation using the infinite-width network limit."
  • Logits: The raw, pre-softmax outputs of a classifier used to compute probabilities. "a single linear layer produces the logits."
  • Meta-learning: Learning to learn; frameworks that optimize across tasks to improve learning efficiency. "meta-learning research"
  • Model distillation: Transferring knowledge from a large or complex model to a smaller one. "Hinton et al.~\cite{hinton2015distilling} proposed model distillation"
  • Neural architecture search: Automatically discovering model architectures that perform well for a given task. "neural architecture search"
  • Neural Tangent Kernel (NTK): A kernel that characterizes infinite-width neural networks, enabling kernel-based training approximations. "where the distilled data is learned with respect to the Neural Tangent Kernel"
  • Parameter space: The high-dimensional space comprising all trainable parameters of a model. "in the parameter space"
  • Parameter trajectory: The sequence of parameter states a model follows during training. "matching segments of parameter trajectories trained on synthetic data"
  • Privacy-preserving ML: Methods that protect sensitive information while training machine learning models. "privacy-preserving ML"
  • SGD with momentum: An optimizer that accelerates SGD by accumulating a velocity vector to damp oscillations and speed convergence. "We use SGD with momentum to optimize $\mathcal{D}_\mathsf{syn}$ and $\alpha$"
  • Training dynamics: The behavior and evolution of model parameters during optimization. "directly imitate the long-range training dynamics of networks trained on real datasets."
  • Trajectory matching: Aligning the learned parameter trajectory on synthetic data with an expert trajectory from real data. "Dataset Distillation via Trajectory Matching"
  • Weight matching loss: An objective that penalizes differences between student and expert parameters to guide synthetic data optimization. "update our distilled images according to the weight matching loss:"
  • ZCA whitening: A decorrelation and whitening transform that removes linear correlations in data while preserving variance. "we employ ZCA whitening as done in previous work"

