How to set AdamW's weight decay as you scale model and dataset size (2405.13698v1)
Abstract: We show that weights learned by AdamW can be understood as an exponential moving average (EMA) of recent updates. This gives critical insights into how to set the weight decay in AdamW, and how the weight decay should scale with model and dataset size. In particular, the key hyperparameter for an exponential moving average is the EMA timescale. Intuitively, the EMA timescale can be understood as the number of recent iterations the EMA averages over. Given a fixed learning rate, there is a one-to-one mapping from the EMA timescale to the usual weight decay hyperparameter, so choosing an EMA timescale implicitly sets the weight decay. Importantly, there are natural guidelines for sensible values of the EMA timescale: we need to average over all datapoints, so the EMA timescale should not be (much) smaller than 1 epoch, and we need to forget early updates, so the EMA timescale should not be (much) bigger than the total number of training epochs. In our experiments, we find that optimal EMA timescales are consistent with these guidelines, as are the hyperparameters chosen in recent large-scale LLM pretraining runs (e.g., Llama 1+2 and Stable LM). Critically, these guidelines suggest that the optimal EMA timescale should not change (much) as we scale the model and dataset. This implies that as the dataset size increases, the optimal weight decay should fall, and as the model size increases, the optimal weight decay should rise (if we follow the muP recommendation for scaling the learning rate).
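To make the mapping concrete, here is a minimal sketch (not the authors' code) of the timescale-to-weight-decay conversion implied by the abstract. It assumes AdamW's decoupled decay multiplies the weights by (1 - learning_rate * weight_decay) each step, so the per-iteration EMA decay factor is (1 - lr * wd) and the EMA timescale is roughly 1/(lr * wd) iterations; the function and variable names (`weight_decay_for_timescale`, `dataset_size`, `batch_size`) are illustrative, not from the paper.

```python
def weight_decay_for_timescale(tau_epochs, lr, dataset_size, batch_size):
    """AdamW weight decay that gives an EMA timescale of `tau_epochs` epochs.

    Assumption: AdamW decays weights by a factor (1 - lr * wd) per step, so the
    EMA timescale is ~1 / (lr * wd) iterations. With dataset_size / batch_size
    iterations per epoch, a timescale of tau_epochs epochs implies
        wd = batch_size / (lr * tau_epochs * dataset_size).
    """
    iters_per_epoch = dataset_size / batch_size
    tau_iters = tau_epochs * iters_per_epoch
    return 1.0 / (lr * tau_iters)


# Guideline from the abstract: pick tau_epochs between ~1 epoch and the total
# number of training epochs, and keep it roughly fixed as you scale.
wd_small = weight_decay_for_timescale(tau_epochs=10, lr=3e-4,
                                      dataset_size=1_000_000, batch_size=512)
wd_large = weight_decay_for_timescale(tau_epochs=10, lr=3e-4,
                                      dataset_size=10_000_000, batch_size=512)
print(wd_small, wd_large)  # weight decay falls ~10x when the dataset grows 10x

# muP-style scaling (assumption: hidden-layer lr scales as 1/width): for a fixed
# timescale, a smaller lr at larger width requires a proportionally larger wd.
base_width, width = 256, 1024
lr_mup = 3e-4 * base_width / width
wd_wide = weight_decay_for_timescale(tau_epochs=10, lr=lr_mup,
                                     dataset_size=1_000_000, batch_size=512)
print(wd_wide)  # ~4x larger than wd_small for a 4x wider model
```

Under these assumptions, holding the timescale fixed makes the weight decay shrink in proportion to dataset size and grow with model width when the learning rate is scaled down following muP, matching the scaling behaviour described in the abstract.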
- Theoretical analysis of auto rate-tuning by batch normalization. In ICLR, 2019.
- How to scale your EMA. In NeurIPS, 2023.
- Symbolic discovery of optimization algorithms. In NeurIPS, 2023.
- A downsampled variant of ImageNet as an alternative to the CIFAR datasets. arXiv preprint arXiv:1707.08819, 2017.
- AutoAugment: Learning augmentation strategies from data. In CVPR, 2019.
- Scaling vision transformers to 22 billion parameters. In ICML, 2023.
- An image is worth 16x16 words: Transformers for image recognition at scale. In ICLR, 2021.
- Deep residual learning for image recognition. In CVPR, 2016.
- Highly scalable deep learning training system with mixed-precision: Training ImageNet in four minutes. arXiv preprint arXiv:1807.11205, 2018.
- Rotational equilibrium: How weight decay balances learning across neural networks. In NeurIPS 2023 Workshop on Mathematics of Modern Machine Learning, 2023.
- Krizhevsky, A. Learning multiple layers of features from tiny images. Technical report, 2009.
- An exponential learning rate schedule for deep learning. In ICLR, 2019.
- Reconciling modern deep learning with traditional optimization analyses: The intrinsic learning rate. In NeurIPS, 2020.
- Robust training of neural networks using scale invariant architectures. In ICML, 2022.
- Lingle, L. A large-scale exploration of µ-transfer. arXiv preprint, 2024.
- Sophia: A scalable stochastic second-order optimizer for language model pre-training. In ICLR, 2024.
- Decoupled weight decay regularization. In ICLR, 2019.
- When does label smoothing help? In NeurIPS, 2019.
- LLaMA: Open and efficient foundation language models. arXiv preprint arXiv:2302.13971, 2023a.
- LLaMA 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023b.
- Technical report for StableLM-3B-4E1T. Technical report, 2023.
- Van Laarhoven, T. L2 regularization versus batch and weight normalization. arXiv preprint arXiv:1706.05350, 2017.
- Spherical motion dynamics: Learning dynamics of normalized neural network using SGD and weight decay. 2021.
- Small-scale proxies for large-scale transformer training instabilities. In ICLR, 2024.
- Tensor programs V: Tuning large neural networks via zero-shot hyperparameter transfer. arXiv preprint arXiv:2203.03466, 2022.
- Three mechanisms of weight decay regularization. In ICLR, 2019.
- OPT: Open pre-trained transformer language models. arXiv preprint arXiv:2205.01068, 2022.
Authors: Xi Wang, Laurence Aitchison