
Early Weight Averaging meets High Learning Rates for LLM Pre-training (2306.03241v2)

Published 5 Jun 2023 in cs.LG, cs.AI, and cs.CL

Abstract: Training LLMs incurs significant cost; hence, any strategy that accelerates model convergence is helpful. In this paper, we investigate the ability of a simple idea, checkpoint averaging along the trajectory of a training run, to improve both convergence and generalization quite early on during training. Here we show that models trained with high learning rates observe higher gains due to checkpoint averaging. Furthermore, these gains are amplified when checkpoints are sampled with considerable spacing in training steps. Our training recipe outperforms conventional training and popular checkpoint averaging baselines such as exponential moving average (EMA) and stochastic weight averaging (SWA). We evaluate our training recipe by pre-training LLMs, where high learning rates are inherently preferred due to extremely large batch sizes. Specifically, we pre-trained nanoGPT-2 models of varying sizes, small (125M), medium (335M), and large (770M), on the OpenWebText dataset, comprising 9B tokens. Additionally, we present results for publicly available Pythia LLMs, ranging from 1B to 12B, which were trained on the PILE-deduped dataset containing 207B tokens.


Summary

  • The paper introduces LAWA (Latest Weight Averaging), a checkpoint averaging method applied early in training that damps high learning rate oscillations to enhance convergence and generalization.
  • The study demonstrates significant efficiency gains and reduced GPU time in experiments with nanoGPT-2 and Pythia models.
  • The results show that LAWA improves zero-shot performance and holds promise for broader applications including federated and diffusion model training.

An Analytical Review of "Early Weight Averaging meets High Learning Rates for LLM Pre-training"

The paper "Early Weight Averaging meets High Learning Rates for LLM Pre-training" adopts a pragmatic approach to improve the efficiency and generalization of training LLMs. By introducing a method termed "Latest Weight Averaging" (LAWA), the authors aim to address computational constraints associated with training LLMs using high learning rates, a scenario commonly occurring in cluster-based training environments with large batch sizes.

Key Contributions and Methodology

The essence of this paper lies in utilizing checkpoint averaging throughout the training trajectory rather than merely in the later phases, as seen in previous methods like Stochastic Weight Averaging (SWA). The innovation of LAWA is its integration early in the training process, thereby helping mitigate oscillations typically observed with high learning rates, without adjusting the learning rate schedule. This results in better generalization and faster convergence, allowing LLMs to retain performance while significantly reducing the computational budget.
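To make the mechanics concrete, here is a minimal PyTorch-style sketch of the recipe as described above: keep a rolling buffer of the last k checkpoints, sampled every tau optimizer steps, and evaluate the uniform average of their weights. This is an illustrative reconstruction, not the authors' reference implementation; the class name `LatestWeightAverager` and the default values of `k` and `tau` are hypothetical.

```python
from collections import deque
import torch

class LatestWeightAverager:
    """Rolling average of the last k checkpoints, sampled every tau steps (illustrative sketch)."""

    def __init__(self, k: int = 5, tau: int = 1000):
        self.k = k        # number of checkpoints kept in the averaging window
        self.tau = tau    # spacing, in optimizer steps, between sampled checkpoints
        self.buffer = deque(maxlen=k)

    def maybe_sample(self, model: torch.nn.Module, step: int) -> None:
        # Wide spacing between sampled checkpoints is what the paper reports
        # as amplifying the benefit of averaging at high learning rates.
        if step > 0 and step % self.tau == 0:
            self.buffer.append({name: tensor.detach().cpu().clone()
                                for name, tensor in model.state_dict().items()})

    def averaged_state_dict(self) -> dict:
        # Uniform average over the buffered checkpoints; non-floating-point
        # buffers (e.g. step counters) are copied from the latest checkpoint.
        latest = self.buffer[-1]
        averaged = {}
        for name, tensor in latest.items():
            if tensor.is_floating_point():
                stacked = torch.stack([ckpt[name] for ckpt in self.buffer])
                averaged[name] = stacked.mean(dim=0)
            else:
                averaged[name] = tensor.clone()
        return averaged
```

In use, one would call `maybe_sample(model, step)` after each optimizer step and periodically load `averaged_state_dict()` into a separate evaluation copy of the model; the raw training weights and the learning rate schedule are left untouched.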

Significant contributions of this work include:

  1. Empirical Exploration: The authors conducted controlled experiments on nanoGPT-2 and Pythia LLMs, exploring the interaction between high learning rates and weight averaging, demonstrating notable improvements in convergence speed and generalization.
  2. Improved Training Efficiency: By using large models like nanoGPT-2 (up to 770M parameters) and Pythia (up to 12B parameters), the paper effectively showcases the practical benefits of LAWA, highlighting its impact on reducing GPU hours while maintaining or even enhancing performance.
  3. Versatility in Application: Extending beyond LLMs, LAWA’s potential was evaluated with a diffusion model for image generation, illustrating its broader applicability in generative model training.
  4. Comprehensive Zero-Shot Performance Analysis: By assessing the zero-shot performance on multiple downstream tasks, the authors provide compelling evidence of LAWA's effectiveness in enhancing model capabilities early in the training cycle.

Numerical Results

The numerical results presented are robust and indicative of significant performance gains:

  • Models trained with LAWA reached lower validation log perplexity during early and mid-training than both conventionally trained models and established baselines such as EMA and SWA.
  • GPU-hour savings are quantified explicitly, with the reported reductions in compute time supporting resource-efficient model development.

Theoretical and Practical Implications

From a theoretical perspective, LAWA sheds light on the relationship between learning rate schedules and generalization. The methodology suggests that averaging well-spaced checkpoints exploits their diversity and damps the detrimental oscillations induced by high learning rates, without altering the schedule itself.
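The contrast with the EMA baseline can be made explicit with a schematic, scalar-weight sketch; `alpha` and the window contents below are illustrative placeholders rather than values taken from the paper.

```python
def ema_update(theta_avg: float, theta_t: float, alpha: float = 0.999) -> float:
    # EMA: updated every step, with exponentially decaying weight on past
    # iterates, so consecutive (highly correlated) checkpoints all contribute.
    return alpha * theta_avg + (1.0 - alpha) * theta_t

def lawa_average(window: list) -> float:
    # LAWA: a uniform average over the last k checkpoints, sampled with
    # considerable spacing in training steps, so the averaged iterates lie
    # farther apart along the trajectory.
    return sum(window) / len(window)
```

The comparison reflects the abstract's observation that the gains from averaging are amplified when checkpoints are sampled with considerable spacing in training steps.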

Practically, LAWA’s seamless integration into existing pre-training pipelines makes it a valuable technique for institutions seeking to shorten training time without compromising model performance. This is particularly crucial in large-scale operations where compute costs are non-negligible.

Future Directions

Looking forward, LAWA’s integration in federated and continual learning paradigms holds promise. The technique could serve as a cornerstone for developing more resilient model adaptation strategies, especially in non-stationary environments or with sequential data inputs.

In conclusion, this paper presents a valuable contribution to the field of efficient large-scale model training. It opens avenues for further work on optimizing computational budgets by exploiting the diversity of checkpoints along a training trajectory to enhance generalization.
