
How to set AdamW's weight decay as you scale model and dataset size (2405.13698v1)

Published 22 May 2024 in cs.LG and cs.AI

Abstract: We show that weights learned by AdamW can be understood as an exponential moving average (EMA) of recent updates. This gives critical insights for how to set the weight decay in AdamW, and how the weight decay should scale with model and dataset size. In particular, the key hyperparameter for an exponential moving average is the EMA timescale. Intuitively, the EMA timescale can be understood as the number of recent iterations the EMA averages over. Given a fixed learning rate, there is a one-to-one mapping from the EMA timescale to the usual weight decay hyperparameter. Thus, choosing an EMA timescale implicitly sets the weight decay. Importantly, there are natural guidelines for sensible values for the EMA timescale: we need to average over all datapoints, so the EMA timescale should not be (much) smaller than 1 epoch, and we need to forget early updates, so the EMA timescale should not be (much) bigger than the total number of training epochs. In our experiments, we find that optimal EMA timescales are consistent with these guidelines, as are the hyperparameters chosen in recent large-scale LLM pretraining runs (e.g., Llama 1+2 and Stable LM). Critically, these guidelines suggest that the optimal EMA timescale should not change (much) as we scale the model and dataset. That implies that as the dataset size increases, the optimal weight decay should fall. Moreover, as the model size increases, the optimal weight decay should also increase (if we follow the muP recommendation for scaling the learning rate).

References (27)
  1. Theoretical analysis of auto rate-tuning by batch normalization. In ICLR, 2019.
  2. How to scale your EMA. In NeurIPS, 2023.
  3. Symbolic discovery of optimization algorithms. In NeurIPS, 2023.
  4. A downsampled variant of ImageNet as an alternative to the CIFAR datasets. arXiv preprint arXiv:1707.08819, 2017.
  5. AutoAugment: Learning augmentation strategies from data. In CVPR, 2019.
  6. Scaling vision transformers to 22 billion parameters. In ICML, 2023.
  7. An image is worth 16x16 words: Transformers for image recognition at scale. In ICLR, 2021.
  8. Deep residual learning for image recognition. In CVPR, 2016.
  9. Highly scalable deep learning training system with mixed-precision: Training ImageNet in four minutes. arXiv preprint arXiv:1807.11205, 2018.
  10. Rotational equilibrium: How weight decay balances learning across neural networks. In NeurIPS 2023 Workshop on Mathematics of Modern Machine Learning, 2023.
  11. Krizhevsky, A. Learning multiple layers of features from tiny images. Technical report, 2009.
  12. An exponential learning rate schedule for deep learning. In ICLR, 2019.
  13. Reconciling modern deep learning with traditional optimization analyses: The intrinsic learning rate. In NeurIPS, 2020.
  14. Robust training of neural networks using scale invariant architectures. In ICML, 2022.
  15. Lingle, L. A large-scale exploration of µ-transfer. 2024.
  16. Sophia: A scalable stochastic second-order optimizer for language model pre-training. In ICLR, 2024.
  17. Decoupled weight decay regularization. In ICLR, 2018.
  18. When does label smoothing help? In NeurIPS, 2019.
  19. LLaMA: Open and efficient foundation language models. arXiv preprint arXiv:2302.13971, 2023a.
  20. LLaMA 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023b.
  21. Technical report for StableLM-3B-4E1T. Technical Report, 2023.
  22. Van Laarhoven, T. L2 regularization versus batch and weight normalization. arXiv preprint arXiv:1706.05350, 2017.
  23. Spherical motion dynamics: Learning dynamics of normalized neural network using SGD and weight decay. 2021.
  24. Small-scale proxies for large-scale transformer training instabilities. In ICLR, 2024.
  25. Tensor programs v: Tuning large neural networks via zero-shot hyperparameter transfer. arXiv preprint arXiv:2203.03466, 2022.
  26. Three mechanisms of weight decay regularization. In ICLR, 2019.
  27. OPT: Open pre-trained transformer language models. arXiv preprint arXiv:2205.01068, 2022.
Authors (2)
  1. Xi Wang (275 papers)
  2. Laurence Aitchison (66 papers)
Citations (3)

Summary

  • The paper introduces an EMA timescale guideline linking weight decay and learning rate to provide a theoretical basis for tuning AdamW.
  • Empirical results reveal that optimal weight decay decreases with larger datasets and increases with model size, following μP scaling principles.
  • The study offers a practical framework for adjusting hyperparameters in large-scale training, potentially enhancing model performance and efficiency.

Understanding AdamW's Weight Decay Scaling with Model and Dataset Size

The paper, titled "How to set AdamW's weight decay as you scale model and dataset size," examines a nuanced aspect of neural network training: how to understand and set the weight decay hyperparameter of the AdamW optimizer. The investigation is anchored in the relationship between weight decay and scale, particularly in large-scale LLM pretraining scenarios.

AdamW, an optimizer widely adopted for its decoupled weight decay, can be examined through the lens of an exponential moving average (EMA): the weight decay controls how recent updates are averaged, much as in an EMA, which opens a principled route to setting this hyperparameter. The critical insight is that the customary weight decay hyperparameter, traditionally tuned by hand or by empirical sweeps, is linked to the learning rate via the EMA timescale $\tau_{\text{iter}} = 1/(\eta \lambda)$, where $\eta$ is the learning rate and $\lambda$ is the weight decay.
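To make the EMA reading concrete, the sketch below unrolls AdamW's decoupled-weight-decay recursion and checks that it matches a discounted (EMA-style) sum of past updates with timescale roughly $1/(\eta\lambda)$. This is an illustrative sketch, not the authors' code; the learning rate, weight decay, and synthetic scalar updates are arbitrary choices for the demonstration.

```python
# Illustrative sketch (not the paper's code): AdamW's decoupled update
#   theta_{t+1} = (1 - lr * wd) * theta_t - lr * u_t
# unrolls into a discounted (EMA-style) sum of past updates u_t with
# decay factor (1 - lr * wd), i.e. timescale tau_iter ~ 1 / (lr * wd).
import numpy as np

def adamw_weight_trajectory(updates, lr, wd, theta0=0.0):
    """Apply the decoupled-weight-decay recursion to a stream of updates."""
    theta = theta0
    for u in updates:
        theta = (1.0 - lr * wd) * theta - lr * u
    return theta

def discounted_sum_of_updates(updates, lr, wd, theta0=0.0):
    """Equivalent closed form: an exponentially weighted sum of past updates."""
    T = len(updates)
    decay = 1.0 - lr * wd
    theta = decay ** T * theta0
    for t, u in enumerate(updates):
        theta -= lr * decay ** (T - 1 - t) * u
    return theta

rng = np.random.default_rng(0)
updates = rng.normal(size=1000)          # stand-in for Adam's normalized updates
lr, wd = 1e-3, 0.1                       # hypothetical hyperparameters
print(1.0 / (lr * wd))                   # EMA timescale in iterations (~1e4)
print(np.isclose(adamw_weight_trajectory(updates, lr, wd),
                 discounted_sum_of_updates(updates, lr, wd)))  # True
```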

Key Findings

  1. EMA Timescale as a Guideline: The paper posits that choosing a sensible value for the EMA timescale $\tau_{\text{iter}}$ implicitly dictates the weight decay parameter. The practical guideline is that the timescale, measured in epochs, should not be (much) smaller than one epoch, so that the EMA averages over all datapoints, and not (much) larger than the total number of training epochs, so that early updates are eventually forgotten.
  2. Scalability Across Models and Datasets: A salient implication of the EMA viewpoint is that the optimal timescale $\tau_{\text{epoch}}$ remains relatively unchanged as the scale of models and datasets varies. Therefore, as the dataset grows, the optimal weight decay should decrease; conversely, as model size grows, following the $\mu$P recommendation for scaling the learning rate implies an increase in the optimal weight decay (see the sketch after this list).
  3. Validation Through Experiments: Empirical tests across model architectures such as ResNet and ViT support the hypothesis, in particular the robustness of the optimal EMA timescale guideline across varied hyperparameter settings. Furthermore, comparisons with the configurations used in large-scale pretraining runs (e.g., the Llama series and Stable LM) underscore the applicability of the EMA-derived scaling principles.
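As a rough illustration of point 2, the helper below converts a chosen timescale in epochs into AdamW's weight decay via $\lambda = 1/(\eta\,\tau_{\text{iter}})$ with $\tau_{\text{iter}} = \tau_{\text{epoch}} \cdot N/B$ for dataset size $N$ and batch size $B$. The function name, hyperparameter values, and the specific µP width-scaling factor are hypothetical choices for the example, not values taken from the paper.

```python
# Hypothetical helper (not from the paper): fix the EMA timescale in epochs
# and derive the implied AdamW weight decay from the learning rate,
# dataset size, and batch size.
def weight_decay_from_timescale(tau_epochs, lr, dataset_size, batch_size):
    iters_per_epoch = dataset_size / batch_size
    tau_iters = tau_epochs * iters_per_epoch   # timescale in iterations
    return 1.0 / (lr * tau_iters)              # lambda = 1 / (lr * tau_iters)

# Keeping tau_epochs fixed while the dataset grows 10x shrinks the implied
# weight decay by 10x.
base = weight_decay_from_timescale(tau_epochs=5, lr=3e-4,
                                   dataset_size=1_000_000, batch_size=512)
more_data = weight_decay_from_timescale(tau_epochs=5, lr=3e-4,
                                        dataset_size=10_000_000, batch_size=512)
print(base, more_data)       # ~0.34 vs ~0.034

# Under muP the hidden-layer learning rate shrinks roughly in proportion to
# width, so at fixed tau_epochs the implied weight decay grows with width.
mup_lr = 3e-4 * (1024 / 4096)            # illustrative 4x wider model
wider = weight_decay_from_timescale(tau_epochs=5, lr=mup_lr,
                                    dataset_size=1_000_000, batch_size=512)
print(wider)                 # ~4x larger than `base`
```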

Theoretical and Practical Implications

Theoretically, the paper moves the understanding of AdamW's weight decay beyond purely empirical adjustment by embedding it in a grounded framework tied to EMA processes. This could fuel further work on other hyperparameters that admit an exponential-averaging interpretation. Practically, the empirical insights may prompt developers to adjust weight decay systematically as they scale, potentially improving the efficiency and performance of large-scale training runs.

Looking forward, further investigations may venture into leveraging this EMA perspective for other optimization algorithms that incorporate decoupled weight decay mechanisms. Additionally, the exploration and validation of the resulting insights in more diverse model architectures and across a spectrum of tasks can strengthen and refine the practical guidelines offered.

In conclusion, the paper raises important considerations on how model scaling interacts with hyperparameter optimization, particularly weight decay, and offers a structured, theoretically and empirically grounded approach to scaling this hyperparameter. Such insights could prove beneficial for refining current training practice, making better use of compute, and improving model performance.