
Deep Grokking: Would Deep Neural Networks Generalize Better?

(2405.19454)
Published May 29, 2024 in cs.LG

Abstract

Recent research on the grokking phenomenon has illuminated the intricacies of neural networks' training dynamics and their generalization behaviors. Grokking refers to a sharp rise in the network's generalization accuracy on the test set, which occurs long after an extended overfitting phase during which the network perfectly fits the training set. While existing research primarily focuses on shallow networks such as 2-layer MLPs and 1-layer Transformers, we explore grokking in deep networks (e.g., a 12-layer MLP). We empirically replicate the phenomenon and find that deep neural networks can be more susceptible to grokking than their shallower counterparts. Meanwhile, we observe an intriguing multi-stage generalization phenomenon when increasing the depth of the MLP model, where the test accuracy exhibits a secondary surge that is scarcely seen in shallow models. We further uncover compelling correspondences between the decrease in feature ranks and the phase transition from overfitting to the generalization stage during grokking. Additionally, we find that the multi-stage generalization phenomenon often aligns with a double-descent pattern in feature ranks. These observations suggest that internal feature rank could serve as a more promising indicator of the model's generalization behavior than the weight norm. We believe our work is the first to examine grokking in deep neural networks and to investigate the relationship between feature rank and generalization performance.

Generalization and feature learning of a 12-layer MLP with varying training data amounts.

Overview

  • The study examines 'grokking' in deep neural networks and finds that deeper networks (e.g., 12-layer MLPs) can be more prone to this phenomenon, which is characterized by a sharp improvement in test accuracy after an extended phase of overfitting.

  • It reveals a multi-stage generalization effect in deep networks: test accuracy surges more than once, and these surges often align with a double-descent pattern in feature ranks. Shallower networks typically show only a single surge.

  • The research highlights a strong correlation between decreases in feature ranks and the transition from overfitting to generalization, suggesting that monitoring feature rank dynamics could provide better indicators of generalization phase transitions than weight-norms.


The paper "Deep Grokking: Would Deep Neural Networks Generalize Better?" by Simin Fan, Razvan Pascanu, and Martin Jaggi explores the phenomenon of "grokking" in deep neural networks. Grokking refers to the abrupt improvement in test set accuracy after an extensive phase of overfitting to the training set. This study provides new insights into the generalization behaviors of deep neural networks, specifically focusing on deep multi-layer perceptrons (MLPs).

Summary of Findings

Grokking in Deep Networks:

  • The authors report that deep neural networks (e.g., a 12-layer MLP) can be more prone to grokking than shallower models. In their experiments, deeper networks show a sharp phase transition to high test accuracy long after reaching high training accuracy.

Multi-stage Generalization:

  • An intriguing finding is a multi-stage generalization effect in deep networks. Unlike shallow networks, which typically display a single surge in test accuracy, deeper networks may show multiple distinct periods of sharp test-accuracy improvement. This multi-stage behavior often aligns with a double-descent pattern in feature ranks.
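To make "multiple distinct surges" concrete, the sketch below counts sharp rises in a logged test-accuracy curve. It is a hypothetical helper, not the authors' procedure: `find_accuracy_surges`, the look-ahead `window`, and the `min_jump` threshold are illustrative assumptions. A single-surge curve would yield one index; a multi-stage curve yields two or more.

```python
from typing import List, Sequence


def find_accuracy_surges(test_acc: Sequence[float], window: int = 5,
                         min_jump: float = 0.1) -> List[int]:
    """Return the indices at which sharp rises in test accuracy are first detected.

    Hypothetical heuristic (not from the paper): index i is "rising" if accuracy
    grows by at least `min_jump` between evaluation i and evaluation i + window.
    Consecutive rising indices are merged, so each surge is reported once.
    """
    surges: List[int] = []
    in_surge = False
    for i in range(len(test_acc) - window):
        rising = test_acc[i + window] - test_acc[i] >= min_jump
        if rising and not in_surge:
            surges.append(i)  # first index whose look-ahead window covers a new surge
        in_surge = rising
    return surges


# Synthetic two-stage curve: plateau, first surge, plateau, second surge.
curve = [0.3] * 20 + [0.6] * 20 + [0.95] * 20
print(find_accuracy_surges(curve))  # [15, 35]
```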

Feature Ranks and Generalization:

  • The study finds that the decrease in feature ranks closely tracks the transition from overfitting to generalization. This suggests that internal feature ranks could serve as a more reliable indicator of a model's generalization phase transition than weight norms.
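One common way to estimate the numerical rank of a feature matrix is to count the singular values that exceed a small fraction of the largest one. The sketch below assumes that estimator with a hypothetical relative tolerance `rel_tol`; the paper's exact threshold and any preprocessing (such as centering the features) may differ.

```python
import numpy as np


def numerical_rank(features: np.ndarray, rel_tol: float = 1e-3) -> int:
    """Count singular values larger than rel_tol times the largest one.

    Assumed estimator (the paper's exact rule may differ). `features` is a
    (num_samples, feature_dim) matrix of one layer's activations on a fixed batch.
    """
    s = np.linalg.svd(features, compute_uv=False)  # singular values, descending
    if s.size == 0 or s[0] == 0.0:
        return 0  # an all-zero matrix has rank 0
    return int(np.sum(s > rel_tol * s[0]))


rng = np.random.default_rng(0)
low_rank = rng.standard_normal((200, 5)) @ rng.standard_normal((5, 64))
full_rank = rng.standard_normal((200, 64))
print(numerical_rank(low_rank), numerical_rank(full_rank))  # 5 64
```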

Emergence of the "Tunnel":

  • The paper corroborates findings from prior research suggesting that deeper networks develop a "tunnel" where later layers compress feature representations into low-rank forms. This compression effect may explain the susceptibility of deeper networks to more pronounced grokking behavior.
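A tunnel shows up as per-layer feature ranks that fall off in the later layers. The sketch below computes such a rank profile for a bias-free ReLU MLP given as a list of weight matrices; the function names, the architecture, and the rank estimator (singular values above a relative tolerance) are assumptions for illustration, not the authors' code.

```python
from typing import List

import numpy as np


def numerical_rank(h: np.ndarray, rel_tol: float = 1e-3) -> int:
    """Count singular values larger than rel_tol times the largest (assumed estimator)."""
    s = np.linalg.svd(h, compute_uv=False)
    return 0 if s.size == 0 or s[0] == 0.0 else int(np.sum(s > rel_tol * s[0]))


def layer_rank_profile(weights: List[np.ndarray], x: np.ndarray,
                       rel_tol: float = 1e-3) -> List[int]:
    """Numerical rank of each hidden layer's ReLU activations on the batch x.

    Hypothetical sketch: the last matrix in `weights` is treated as the output
    head and skipped, and layers are assumed to have no biases.
    """
    ranks = []
    h = x
    for w in weights[:-1]:
        h = np.maximum(h @ w, 0.0)  # bias-free ReLU layer
        ranks.append(numerical_rank(h, rel_tol))
    return ranks
```

Logging `layer_rank_profile` on a fixed probe batch at each evaluation step gives one rank curve per layer; a tunnel would appear as consistently lower ranks in the deepest layers.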

Experimental Setup

The authors trained MLPs of varying depths on the MNIST dataset, using large initialization and small weight decay to reproduce the grokking phenomenon. Models were trained with the Adam optimizer (learning rate 1 × 10⁻³) on a mean squared error (MSE) loss for 100,000 steps. To analyze the learned representations, the study used linear probing accuracy and numerical rank estimation of internal features.
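The pieces of this setup can be sketched in NumPy: a deep bias-free ReLU MLP whose He-style initialization is multiplied by a factor `alpha` (one way to realize "large initialization"), an MSE loss against one-hot MNIST labels, and a linear probe on frozen hidden features. Everything beyond the stated optimizer, learning rate, loss, and step count is a hypothetical placeholder, including the width, `alpha`, the weight-decay value, and all helper names. The probe is a ridge-regularized least-squares classifier swapped in for brevity; the paper's probe may be trained differently. The Adam training loop itself is omitted.

```python
import numpy as np

# Stated in the summary: Adam, lr 1e-3, MSE loss, 100,000 steps.
# Hypothetical placeholders: depth/width choice, init scale alpha, weight-decay value.
CONFIG = {"optimizer": "adam", "lr": 1e-3, "steps": 100_000, "loss": "mse",
          "depth": 12, "width": 256, "alpha": 2.0, "weight_decay": 1e-4}


def init_mlp(depth, in_dim=784, width=256, out_dim=10, alpha=2.0, seed=0):
    """He-style init for a bias-free `depth`-layer MLP, scaled up by `alpha` (assumption)."""
    rng = np.random.default_rng(seed)
    dims = [in_dim] + [width] * (depth - 1) + [out_dim]
    return [alpha * rng.standard_normal((d_in, d_out)) * np.sqrt(2.0 / d_in)
            for d_in, d_out in zip(dims[:-1], dims[1:])]


def forward(weights, x):
    """Return the network outputs and the list of hidden ReLU activations."""
    hidden = []
    h = x
    for w in weights[:-1]:
        h = np.maximum(h @ w, 0.0)
        hidden.append(h)
    return h @ weights[-1], hidden


def mse_loss(outputs, labels, num_classes=10):
    """Mean squared error between outputs and one-hot encoded integer labels."""
    targets = np.eye(num_classes)[labels]
    return float(np.mean((outputs - targets) ** 2))


def linear_probe_accuracy(train_feats, train_labels, test_feats, test_labels,
                          num_classes=10, ridge=1e-3):
    """Swapped-in ridge least-squares probe on frozen features; returns test accuracy."""
    def with_bias(f):
        return np.hstack([f, np.ones((f.shape[0], 1))])

    X = with_bias(train_feats)
    Y = np.eye(num_classes)[train_labels]
    W = np.linalg.solve(X.T @ X + ridge * np.eye(X.shape[1]), X.T @ Y)
    preds = np.argmax(with_bias(test_feats) @ W, axis=1)
    return float(np.mean(preds == test_labels))
```

Because ReLU is positively homogeneous and these layers have no biases, scaling every weight matrix by `alpha` scales the network output by `alpha**depth`, so in this sketch the effective initialization scale grows quickly with depth.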

Key Observations

Deeper Networks and Grokking:

  • Deeper MLPs tend to exhibit delayed growth in both training and test accuracies, indicating a more severe overfitting phase followed by a sharp generalization transition.

Impact of Training Data Size:

  • Small training sets induce severe grokking, while larger training sets mitigate this effect, resulting in earlier phase transitions.
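One simple way to compare how severe grokking is across depths or training-set sizes is the gap between the step at which training accuracy first reaches a threshold and the step at which test accuracy does. This metric, its `threshold`, and the helper names are illustrative assumptions rather than a measure defined in the paper.

```python
from typing import Optional, Sequence


def first_step_reaching(steps: Sequence[int], acc: Sequence[float],
                        threshold: float) -> Optional[int]:
    """Return the first logged step whose accuracy is >= threshold, or None."""
    for step, a in zip(steps, acc):
        if a >= threshold:
            return step
    return None


def grokking_delay(steps: Sequence[int], train_acc: Sequence[float],
                   test_acc: Sequence[float], threshold: float = 0.9) -> Optional[int]:
    """Steps between train and test accuracy first reaching `threshold`.

    Illustrative metric, not defined in the paper: a larger value means a longer
    overfitting phase. Returns None if either curve never reaches the threshold.
    """
    t_train = first_step_reaching(steps, train_acc, threshold)
    t_test = first_step_reaching(steps, test_acc, threshold)
    if t_train is None or t_test is None:
        return None
    return t_test - t_train


steps = list(range(0, 10_000, 1_000))
train = [0.1, 0.5, 0.95, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
test = [0.1, 0.2, 0.3, 0.3, 0.3, 0.3, 0.92, 0.95, 0.96, 0.96]
print(grokking_delay(steps, train, test))  # 4000
```

Under the observations above, runs with smaller training sets would be expected to show larger delays than runs with larger ones.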

Rank Collapse and Generalization:

  • Consistently across models and training setups, the rise in test accuracy coincides with a significant drop in feature ranks, suggesting that rank dynamics offer a more nuanced view of the model's generalization process.
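Checking whether a rank drop and an accuracy rise coincide can be done directly on logged curves. The sketch below, a hypothetical helper rather than the paper's analysis, finds the evaluation step with the largest single-interval drop in feature rank and the one with the largest rise in test accuracy, then reports the gap between them. It assumes both series were logged at the same steps and contain at least two points.

```python
from typing import Sequence, Tuple


def largest_change_step(steps: Sequence[int], values: Sequence[float],
                        direction: int) -> int:
    """Step ending the interval with the largest change in `values`.

    Hypothetical helper: direction=+1 looks for the largest rise, direction=-1
    for the largest drop. Requires at least two logged values.
    """
    best = max(range(1, len(values)),
               key=lambda i: direction * (values[i] - values[i - 1]))
    return steps[best]


def transition_alignment(steps: Sequence[int], feature_rank: Sequence[float],
                         test_acc: Sequence[float]) -> Tuple[int, int, int]:
    """Return (step of largest rank drop, step of largest accuracy rise, gap)."""
    rank_drop = largest_change_step(steps, feature_rank, -1)
    acc_rise = largest_change_step(steps, test_acc, +1)
    return rank_drop, acc_rise, abs(acc_rise - rank_drop)


steps = list(range(0, 10_000, 1_000))
rank = [60, 60, 58, 57, 57, 30, 20, 20, 20, 20]
acc = [0.2, 0.2, 0.2, 0.25, 0.25, 0.6, 0.9, 0.95, 0.95, 0.95]
print(transition_alignment(steps, rank, acc))  # (5000, 5000, 0)
```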

Weight-Norm and Phase Transitions:

  • The trajectories of weight-norms do not provide clear indications of phase transitions, challenging the hypothesis that weight-norm is a reliable indicator for grokking.
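For comparison, the weight-norm trajectory is commonly tracked as the global L2 norm of all parameters, i.e. the square root of the summed squared Frobenius norms of the layers. The sketch below computes that quantity plus per-layer norms for a list of NumPy weight matrices; the function name and return layout are illustrative assumptions. Logging it next to the feature rank at each evaluation step lets both signals be compared against the test-accuracy curve.

```python
from typing import List, Sequence, Tuple

import numpy as np


def weight_norms(weights: Sequence[np.ndarray]) -> Tuple[float, List[float]]:
    """Illustrative helper: (global L2 norm over all parameters, per-layer Frobenius norms)."""
    per_layer = [float(np.linalg.norm(w)) for w in weights]
    total = float(np.sqrt(sum(n ** 2 for n in per_layer)))
    return total, per_layer


total, per_layer = weight_norms([np.ones((2, 2)), np.full((3,), 2.0)])
print(round(total, 6), [round(n, 6) for n in per_layer])  # 4.0 [2.0, 3.464102]
```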

Practical and Theoretical Implications

This study provides a comprehensive empirical basis for understanding the complex training dynamics of deep networks. Practically, it suggests that monitoring feature rank dynamics can offer more precise indicators of model phase transitions during training. Theoretically, it challenges existing notions about the relationship between model initialization, regularization, and generalization.

The paper's findings imply that further exploration into the internal feature dynamics of neural networks can enhance our understanding of generalization behaviors. Future work might focus on validating these observations across other deep learning architectures like Transformers and RNNs, and extending the theoretical frameworks to account for the novel multi-stage generalization phenomena observed.

Conclusion

In conclusion, "Deep Grokking: Would Deep Neural Networks Generalize Better?" provides significant insights into the generalization behaviors of deep MLPs. It establishes a connection between feature rank dynamics and phase transitions during training and shows that deep networks can be more susceptible to grokking. The multi-stage generalization phenomenon and the emergence of the "tunnel" effect offer new avenues for both theoretical investigation and practical monitoring of neural network training dynamics.
