Neural Networks Learn Statistics of Increasing Complexity

(2402.04362)
Published Feb 6, 2024 in cs.LG

Abstract

The distributional simplicity bias (DSB) posits that neural networks learn low-order moments of the data distribution first, before moving on to higher-order correlations. In this work, we present compelling new evidence for the DSB by showing that networks automatically learn to perform well on maximum-entropy distributions whose low-order statistics match those of the training set early in training, then lose this ability later. We also extend the DSB to discrete domains by proving an equivalence between token $n$-gram frequencies and the moments of embedding vectors, and by finding empirical evidence for the bias in LLMs. Finally, we use optimal transport methods to surgically edit the low-order statistics of one class to match those of another, and show that early-training networks treat the edited samples as if they were drawn from the target class. Code is available at https://github.com/EleutherAI/features-across-time.

Figure: MNIST learning curves show only slight non-monotonicity, since the first and second moments are close to sufficient statistics for the data distribution.

Overview

  • Neural networks exhibit a distributional simplicity bias (DSB), learning simpler data patterns before complex ones.

  • The study analyzes the relationship between data moments and a network's expected loss using a Taylor series expansion.

  • Empirical tests show that early in training, networks are sensitive to low-order statistics, with high-order sensitivity developing gradually.

  • The study also extends the DSB to discrete domains such as language, drawing an equivalence between token n-gram frequencies and the moments of embedding vectors.

Introduction

Neural networks are known for their exceptional ability to fit highly complex data and still generalize beyond the training set. This adaptability is particularly intriguing given networks' capacity to fit even noisy or random labels, and it may have its roots in the distributional simplicity bias (DSB): the hypothesis that neural networks learn simpler patterns, namely the low-order moments of the data, before capturing more intricate, higher-order correlations. A new study extends this concept, examining how the DSB manifests across data domains and across phases of model training.

Theory and Methods

The researchers leverage a Taylor series expansion of the expected loss to connect the moments of the data distribution with the loss a network incurs. If the first few terms of this series approximate a network's expected loss well, then the network is effectively sensitive to the data's moments only up to that order. Two criteria are outlined to test a model's reliance on low-order moments: editing the low-order statistics of one class to match those of another should change the model's classification correspondingly, and tampering with higher-order statistics should not significantly degrade performance. A sketch of the expansion follows.
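
To make this concrete, here is a hedged sketch of the expansion in our own notation (not necessarily the paper's): fix a trained network, write $\ell(x)$ for its per-sample loss, and Taylor-expand around the data mean $\mu = \mathbb{E}[x]$. Taking expectations, the first-order term vanishes and

```latex
\mathbb{E}_{x \sim p}\bigl[\ell(x)\bigr]
  = \ell(\mu)
  + \tfrac{1}{2}\,\mathrm{Tr}\!\bigl(\nabla^{2}\ell(\mu)\,\Sigma\bigr)
  + \sum_{|\alpha| \ge 3} \frac{\partial^{\alpha}\ell(\mu)}{\alpha!}\, M_{\alpha},
\qquad
\Sigma = \mathrm{Cov}(x), \quad
M_{\alpha} = \mathbb{E}\bigl[(x - \mu)^{\alpha}\bigr].
```

Truncating after the covariance term describes a network that, to this approximation, sees only the first two moments; the two criteria above probe whether real networks behave that way at a given point in training.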

Empirical Findings

This theory is then put to empirical test across a range of network architectures and datasets, including image datasets modified so that their statistical makeup predominantly reflects first and second moments. Early in training, network performance tracks modifications to low-order statistics closely; this sensitivity lessens as training progresses, with models gradually becoming responsive to higher-order statistics, indicating that the DSB is dynamic throughout learning. Moreover, editing the low-order statistics of images from one class to match those of another led early-training networks to classify the edited images into the target class, satisfying the first criterion above.
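
As a hedged illustration of the editing step, here is a minimal sketch in our own code (the function name `edit_low_order_stats` is hypothetical, and we assume the closed-form optimal transport map between Gaussians; the paper's exact procedure may differ):

```python
import numpy as np
from scipy.linalg import sqrtm

def edit_low_order_stats(X_src, X_tgt):
    """Map source-class samples so their mean and covariance match the
    target class, using the closed-form optimal transport map between
    Gaussians: T(x) = mu_t + A (x - mu_s)."""
    mu_s, mu_t = X_src.mean(axis=0), X_tgt.mean(axis=0)
    S_s = np.cov(X_src, rowvar=False)
    S_t = np.cov(X_tgt, rowvar=False)
    s_half = np.real(sqrtm(S_s))             # S_s^{1/2}
    s_half_inv = np.linalg.pinv(s_half)      # pinv tolerates singular covariances
    # A = S_s^{-1/2} (S_s^{1/2} S_t S_s^{1/2})^{1/2} S_s^{-1/2}   (symmetric)
    A = s_half_inv @ np.real(sqrtm(s_half @ S_t @ s_half)) @ s_half_inv
    return (X_src - mu_s) @ A + mu_t         # A == A.T, so right-multiplying works
```

Applied to flattened images (followed by a clip back to the valid pixel range), this swaps the mean and covariance of one class onto another while leaving higher-order structure largely intact.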

Extension to Discrete Domains

A particularly intriguing aspect of the research is its extension to discrete domains, namely language, via an equivalence between token n-gram frequencies and the moments of embedding vectors. Examining LLMs, the study found evidence for the DSB along with a "double descent" effect: the loss measured on n-gram-matched distributions first falls, then rises, then declines again later in training, with the second descent attributed to in-context learning, a phenomenon where models exploit recent context to make predictions.
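
The bigram case gives a minimal sketch of that equivalence, in our own code with one-hot "embeddings" (`bigram_freqs_as_moment` is a hypothetical name; the paper's result is more general):

```python
import numpy as np

def bigram_freqs_as_moment(tokens, vocab_size):
    """With one-hot token 'embeddings', the second cross-moment
    E[e_t e_{t+1}^T] is exactly the empirical bigram frequency table."""
    X = np.eye(vocab_size)[tokens]                    # (T, V) one-hot rows
    moment = X[:-1].T @ X[1:] / (len(tokens) - 1)     # cross-moment of neighbors

    counts = np.zeros((vocab_size, vocab_size))       # direct bigram count
    for a, b in zip(tokens[:-1], tokens[1:]):
        counts[a, b] += 1
    assert np.allclose(moment, counts / (len(tokens) - 1))
    return moment
```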

Conclusion

Altogether, this research provides robust empirical backing for the DSB conjecture while breaking new ground in understanding the statistical learning dynamics of neural networks. It documents models learning statistics of increasing complexity over the course of training and shows how this progression interacts with architecture and training time. These insights could inform methods for deliberately shaping learning trajectories in AI systems, steering them toward desired generalization behaviors.
