Heavy-Tailed Class Imbalance and Why Adam Outperforms Gradient Descent on Language Models (2402.19449v2)

Published 29 Feb 2024 in cs.LG, cs.CL, math.OC, and stat.ML

Abstract: Adam has been shown to outperform gradient descent on LLMs by a larger margin than on other tasks, but it is unclear why. We show that a key factor in this performance gap is the heavy-tailed class imbalance found in language tasks. When trained with gradient descent, the loss of infrequent words decreases more slowly than the loss of frequent ones. This leads to a slow decrease on the average loss as most samples come from infrequent words. On the other hand, Adam and sign-based methods are less sensitive to this problem. To establish that this behavior is caused by class imbalance, we show empirically that it can be reproduced across architectures and data types, on language transformers, vision CNNs, and linear models. On a linear model with cross-entropy loss, we show that class imbalance leads to imbalanced, correlated gradients and Hessians that have been hypothesized to benefit Adam. We also prove that, in continuous time, gradient descent converges slowly on low-frequency classes while sign descent does not.

Authors (5)
  1. Frederik Kunstner
  2. Robin Yadav
  3. Alan Milligan
  4. Mark Schmidt
  5. Alberto Bietti
Citations (14)

Summary

  • The paper shows that Adam’s adaptive, sign-like updates mitigate the optimization slowdown caused by heavy-tailed class imbalance, yielding more uniform progress across frequent and rare classes.
  • Experiments across language transformers, vision CNNs, and linear models show that SGD makes little progress on infrequent classes, while Adam reduces their training loss at a rate comparable to frequent ones.
  • The study's insights offer practical modifications for SGD and guide future research on optimizer designs for imbalanced datasets.

Heavy-Tailed Class Imbalance: Exploring Adam's Superiority over Gradient Descent in LLMs

Introduction

The optimization of LLMs is crucial for advancing the field of NLP. A notable recent observation is the distinct advantage that the Adam optimizer holds over traditional stochastic gradient descent (SGD) when training these models. The paper investigates this phenomenon, attributing the performance disparity to the heavy-tailed class imbalance inherent in language modeling tasks.

Heavy-Tailed Class Imbalance

Language data characteristically follows a heavy-tailed class distribution: a small number of words (or tokens) are very frequent, while the vast majority are rare. Gradient descent makes slow progress on these low-frequency classes, and because they collectively account for a large fraction of the samples, overall training efficiency suffers. In contrast, Adam and similar sign-based methods do not exhibit this limitation and learn classes at more uniform speeds. The researchers substantiate their argument empirically across architectures and data types, including language transformers, vision CNNs, and linear models, highlighting that the effect is driven by class imbalance rather than by language data specifically.
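
As a rough illustration of what "heavy-tailed" means here, the following sketch (mine, not from the paper) counts token frequencies in a synthetic Zipf-distributed corpus and reports how much of the data the most frequent tokens cover; real tokenized text behaves similarly.

```python
from collections import Counter

import numpy as np

# Illustrative sketch (not the paper's code): quantify how heavy-tailed a
# token distribution is by checking how much of the data the most frequent
# tokens cover. A Zipf sample stands in for a real tokenized corpus.
corpus_tokens = np.random.zipf(a=1.5, size=100_000)

counts = Counter(corpus_tokens.tolist())
sorted_counts = np.array(sorted(counts.values(), reverse=True))
total = sorted_counts.sum()

top_1pct = max(1, len(sorted_counts) // 100)
print(f"unique tokens: {len(sorted_counts)}")
print(f"top 1% most frequent tokens cover "
      f"{sorted_counts[:top_1pct].sum() / total:.1%} of all occurrences")
```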

Experimental Insights

The distinction between Adam and SGD becomes particularly pronounced when training performance is disaggregated by class frequency. Experiments show that while SGD barely makes progress on low-frequency classes, the training loss of these classes decreases much more uniformly under Adam. This behavior persists across architectures and data types, reinforcing the core thesis that heavy-tailed class imbalance drives much of the optimization gap between Adam and SGD. Notably, simpler optimizers such as sign descent, which keep only the per-coordinate sign of the gradient and discard its magnitude, avoid the problem as well, suggesting that Adam's benefit stems largely from its sign-like normalization of update magnitudes. A sketch of how such a per-frequency-group loss breakdown can be computed is given below.
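
A minimal way to produce this kind of diagnostic, assuming a standard PyTorch training loop and a hypothetical `class_counts` tensor of per-class sample counts (the names are mine, not the paper's):

```python
import torch
import torch.nn.functional as F

# Sketch (not the paper's code): report the training loss separately for
# groups of classes ordered by frequency. `class_counts` is a hypothetical
# tensor with the number of training samples per class.
def loss_by_frequency_group(logits, targets, class_counts, n_groups=3):
    per_sample = F.cross_entropy(logits, targets, reduction="none")
    # Rank classes from most to least frequent and bucket them into n_groups.
    order = torch.argsort(class_counts, descending=True)
    group_of_class = torch.empty_like(order)
    group_of_class[order] = torch.arange(len(order)) * n_groups // len(order)
    group_of_sample = group_of_class[targets]
    return {g: per_sample[group_of_sample == g].mean().item()
            for g in range(n_groups) if (group_of_sample == g).any()}
```

Logging these group-wise losses over training makes the gap visible: under SGD the low-frequency groups plateau, while under Adam all groups decrease at comparable rates.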

Theoretical Perspectives

On a linear model trained with cross-entropy under heavy-tailed class imbalance, the authors show that the scale of both the gradient and the Hessian blocks associated with each class reflects that class's frequency. This produces ill-conditioning: gradient descent converges at vastly different speeds across classes, and slowly on rare ones. Adam's efficiency in this setting can be partially attributed to its preconditioning, which approximately counteracts the ill-conditioning by normalizing per-coordinate gradient magnitudes. This suggests that, at least for softmax classification with linear models, Adam implicitly compensates for the differential scaling induced by class frequencies, yielding more balanced training dynamics.
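
To make the frequency scaling concrete, here is a short sketch in my own notation, not the paper's exact statement. For a linear model with softmax probabilities $p_k(x)$ over $C$ classes and average cross-entropy loss over $n$ samples, the gradient with respect to the weight vector $w_k$ of class $k$ is

$$
\nabla_{w_k}\mathcal{L}(W) \;=\; \frac{1}{n}\sum_{i=1}^{n}\big(p_k(x_i) - \mathbb{1}[y_i = k]\big)\,x_i .
$$

Near initialization $p_k(x_i) \approx 1/C$ is small, so the sum is dominated by the $n\pi_k$ samples with $y_i = k$, and the gradient norm is roughly proportional to the class frequency $\pi_k$. Plain gradient descent therefore takes correspondingly small steps on rare classes, while sign descent (and, approximately, Adam) applies updates whose per-coordinate magnitude does not shrink with $\pi_k$.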

Broader Implications

This paper not only elucidates why Adam outperforms SGD when training LLMs but also sheds light on potential improvements in other domains where class imbalance is prevalent. The insights could inform new optimization algorithms or adjustments to existing ones, especially in tasks beyond language modeling. Moreover, the demonstrated effectiveness of simple modifications, such as loss reweighting, provides practical avenues for narrowing the gap between SGD and Adam; a sketch of one such reweighting is given below.
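
A minimal sketch of frequency-based reweighting, assuming inverse-frequency weights and a hypothetical `class_counts` tensor; the exact scheme studied in the paper may differ:

```python
import torch.nn.functional as F

# Hedged sketch of frequency-based loss reweighting for SGD; the exact scheme
# studied in the paper may differ. `class_counts` is a hypothetical tensor
# holding the number of training samples per class.
def reweighted_cross_entropy(logits, targets, class_counts):
    # Inverse-frequency weights, normalized so a balanced dataset gets weight 1.
    counts = class_counts.clamp(min=1).float()
    weights = counts.sum() / (len(counts) * counts)
    return F.cross_entropy(logits, targets, weight=weights)
```

The intent is to make SGD's effective per-class step sizes more uniform, mimicking part of the compensation that Adam's preconditioning provides implicitly.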

Future Directions

The analysis raises several questions for future research. In particular, understanding the full ramifications of heavy-tailed class imbalance for generalization, and identifying other architectures where similar optimization dynamics are at play, are compelling directions. The observed alignment between gradient and Hessian scales and class frequencies also opens theoretical avenues for designing optimizers that exploit this relationship more explicitly.

In summary, the paper provides a thorough examination of the challenges posed by heavy-tailed class imbalance in optimizing LLMs, revealing the underlying reasons for Adam's superiority over SGD. Acting on these insights can not only improve the training efficiency of LLMs but also inform optimization strategies in other domains facing similar issues.