Early Weight Averaging meets High Learning Rates for LLM Pre-training (2306.03241v2)

Published 5 Jun 2023 in cs.LG, cs.AI, and cs.CL

Abstract: Training LLMs incurs significant cost; hence, any strategy that accelerates model convergence is helpful. In this paper, we investigate the ability of a simple idea, checkpoint averaging along the trajectory of a training run, to improve both convergence and generalization quite early on during training. Here we show that models trained with high learning rates observe higher gains due to checkpoint averaging. Furthermore, these gains are amplified when checkpoints are sampled with considerable spacing in training steps. Our training recipe outperforms conventional training and popular checkpoint averaging baselines such as exponential moving average (EMA) and stochastic weight averaging (SWA). We evaluate our training recipe by pre-training LLMs, where high learning rates are inherently preferred due to extremely large batch sizes. Specifically, we pre-trained nanoGPT-2 models of varying sizes, small (125M), medium (335M), and large (770M), on the OpenWebText dataset, comprising 9B tokens. Additionally, we present results for publicly available Pythia LLMs, ranging from 1B to 12B, which were trained on the PILE-deduped dataset containing 207B tokens.

Summary

  • The paper introduces LAWA, a method leveraging early checkpoint averaging as a surrogate for learning rate decay to boost convergence in LLM pre-training.
  • The study shows that LAWA outperforms traditional techniques like EMA and SWA, delivering faster training on nanoGPT-2 and Pythia models.
  • The approach enhances zero-shot task performance and scalability, reducing computational costs while improving model generalization.

Introduction

The paper "Early Weight Averaging meets High Learning Rates for LLM Pre-training" presents a novel approach to improve the convergence and generalization of LLMs during pre-training by employing early checkpoint averaging. The methodology is particularly effective when combined with high learning rates, which are inherently preferred due to the large batch sizes involved in LLM pre-training. The paper evaluates various pre-training techniques on nanoGPT-2 and Pythia models, demonstrating substantial improvements over conventional training methods and popular averaging baselines.

Methodology

Optimization and Diversity Insight

The paper proposes averaging checkpoints early during training, arguing that this technique can act as a surrogate for learning rate decay. By averaging weights post hoc along a high-learning-rate trajectory, the approach damps oscillations in sensitive weight directions while improving generalization.

Additionally, the paper incorporates insights from model ensembling literature, suggesting that averaging distant checkpoints in a training trajectory induces model diversity. Increased diversity in checkpoints correlates with improved ensemble performance, thus enhancing the robustness and generalization capabilities of the final model.

LAWA Algorithm

The Latest Weight Averaging (LAWA) algorithm is adapted for LLM pre-training by sampling checkpoints at regular, widely spaced intervals and maintaining a sliding window of the most recent sampled checkpoints for averaging. The approach avoids restarting training with new learning rate schedules and requires no special handling of batch normalization, simplifying integration into large-scale training regimes.

(Pseudocode for LAWA is provided in Algorithm 1 of the paper, written in a PyTorch-style notation; a minimal sketch of the same idea follows below.)
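The following sketch illustrates the core mechanism under stated assumptions: a checkpoint is sampled every `interval` optimizer steps, the latest `k` are kept in a sliding window, and their parameters are averaged uniformly for evaluation. The names `k`, `interval`, `maybe_collect`, and `evaluate_lawa`, as well as the specific values, are illustrative and not taken from the paper.

```python
import copy
from collections import deque

import torch


def lawa_average(state_dicts):
    """Uniformly average a list of model state dicts (LAWA-style)."""
    avg = copy.deepcopy(state_dicts[0])
    for key in avg:
        if torch.is_floating_point(avg[key]):
            avg[key] = torch.stack([sd[key].float() for sd in state_dicts]).mean(dim=0)
    return avg


# Illustrative integration into a training loop (values are assumptions, not the
# paper's hyperparameters): sample a checkpoint every `interval` steps and keep
# only the latest `k` in a sliding window.
k, interval = 6, 1000
window = deque(maxlen=k)


def maybe_collect(step, model):
    if step > 0 and step % interval == 0:
        # Store a CPU copy so the window does not hold extra GPU memory.
        window.append({name: p.detach().cpu().clone()
                       for name, p in model.state_dict().items()})


def evaluate_lawa(model_template):
    """Build an averaged copy of the model for evaluation; training is untouched."""
    if len(window) < 2:
        return None
    eval_model = copy.deepcopy(model_template)
    eval_model.load_state_dict(lawa_average(list(window)))
    return eval_model
```

Because the averaging happens on saved checkpoints rather than inside the optimizer, it can also be applied after the fact to checkpoints from an existing run, which is how the Pythia results are obtained.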

Experimental Setup

The paper details the controlled experimental setup conducted across nanoGPT-2 models (125M, 335M, 770M parameters) and larger Pythia LLMs (up to 12B parameters). The models are trained on substantial datasets like OpenWebText and PILE-deduped, with evaluations conducted on held-out test sets and through zero-shot performance on downstream tasks such as Lambada OpenAI and SciQ.

The LAWA approach is validated by tracking log perplexity and zero-shot task performance across checkpoints sampled throughout the pre-training trajectory.
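As a concrete illustration of the perplexity metric, below is a minimal sketch of how validation log perplexity could be computed; it assumes a Hugging Face-style causal LM interface and a dataloader of token-id batches, neither of which is specified by the paper.

```python
import torch


@torch.no_grad()
def log_perplexity(model, dataloader, device="cuda"):
    """Average per-token negative log-likelihood on a held-out set.

    Assumes each batch is a LongTensor of token ids with shape (batch, seq_len)
    and that `model(input_ids=..., labels=...)` returns an object whose `.loss`
    is the mean cross-entropy over predicted tokens (Hugging Face-style API).
    """
    model.eval()
    total_nll, total_tokens = 0.0, 0
    for input_ids in dataloader:
        input_ids = input_ids.to(device)
        out = model(input_ids=input_ids, labels=input_ids)
        n_predicted = input_ids.numel() - input_ids.size(0)  # one label shifted off per sequence
        total_nll += out.loss.item() * n_predicted
        total_tokens += n_predicted
    avg_nll = total_nll / total_tokens
    return avg_nll  # log perplexity; exponentiate to obtain perplexity


# Hypothetical usage: compare a raw checkpoint against its LAWA average.
# nll_raw = log_perplexity(model, val_loader)
# nll_lawa = log_perplexity(evaluate_lawa(model), val_loader)
```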

(Figure 1 of the paper shows larger validation-loss improvements with LAWA than with EMA and SWA across the different nanoGPT-2 model scales.)

Results

Early and Efficient Convergence

The experimental results confirm that models trained with higher learning rates converge faster and reach lower validation loss when LAWA is applied. The gain is most pronounced in the early stages of training and diminishes as training progresses, once the baseline's learning rate schedule has decayed.

The paper reports that LAWA consistently outperforms standard training as well as the baseline averaging techniques, Exponential Moving Average (EMA) and Stochastic Weight Averaging (SWA). Notably, SWA diverged when applied early in training, underscoring LAWA's robustness.
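For contrast with LAWA's uniform average over widely spaced checkpoints, the EMA baseline maintains a running exponentially weighted average that is updated at every step. A minimal sketch is shown below; the decay value and class name are illustrative assumptions, not the paper's settings.

```python
import copy

import torch


class EMABaseline:
    """Exponential moving average of model weights, updated every optimizer step."""

    def __init__(self, model, decay=0.999):  # decay value is an assumption
        self.decay = decay
        self.shadow = {name: p.detach().clone()
                       for name, p in model.state_dict().items()
                       if torch.is_floating_point(p)}

    @torch.no_grad()
    def update(self, model):
        for name, p in model.state_dict().items():
            if name in self.shadow:
                # shadow <- decay * shadow + (1 - decay) * current weights
                self.shadow[name].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)

    def averaged_model(self, model):
        """Return a deep copy of `model` carrying the EMA weights, for evaluation."""
        eval_model = copy.deepcopy(model)
        state = eval_model.state_dict()
        state.update({name: v.clone() for name, v in self.shadow.items()})
        eval_model.load_state_dict(state)
        return eval_model
```

Because EMA weights every recent step almost equally and discounts older ones smoothly, its constituent models are far less diverse than LAWA's widely spaced checkpoints, which is consistent with the diversity argument made in the methodology.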

(Figure 2 shows the training trajectories of the nanoGPT-2 models with LAWA versus standard training, with clear improvements in convergence rate and final performance metrics.)

Scalability to Large Models

Experiments on Pythia models indicate that LAWA improves generalization across scales, with the accelerated convergence translating into significant savings in compute and training cost. The gains observed for larger models suggest that checkpoint sampling strategies need to account for the training dynamics induced by the large learning rates used at scale.

(Figure 3 highlights the substantial convergence savings for Pythia-2.8B and 6.9B models through LAWA-derived checkpoints compared to traditional training checkpoints.)

Improved Zero-Shot Performance

LAWA improves zero-shot performance across a range of downstream tasks, supporting the correlation between lower training perplexity and higher zero-shot accuracy. These improvements hold consistently across checkpoints, offering computational savings and early-stopping opportunities in compute-constrained settings.
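For context, zero-shot accuracy on multiple-choice tasks such as SciQ is typically measured by scoring each answer option with the model's log-likelihood and picking the highest-scoring option. Below is a minimal sketch of that standard protocol (mirroring common evaluation harnesses, not code from the paper); the function names and the Hugging Face-style interface are assumptions.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def option_log_likelihood(model, tokenizer, prompt, option, device="cuda"):
    """Sum of log-probabilities the model assigns to the tokens of `option` given `prompt`.

    Assumes a Hugging Face-style causal LM and tokenizer; this mirrors the common
    zero-shot multiple-choice protocol rather than code from the paper itself.
    """
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    option_ids = tokenizer(option, return_tensors="pt",
                           add_special_tokens=False).input_ids.to(device)
    input_ids = torch.cat([prompt_ids, option_ids], dim=1)
    logits = model(input_ids=input_ids).logits
    # Logits at position t predict token t + 1, so the option tokens are scored
    # by the logits starting one position before the option begins.
    option_logits = logits[0, prompt_ids.size(1) - 1 : -1]
    log_probs = F.log_softmax(option_logits, dim=-1)
    return log_probs.gather(1, option_ids[0].unsqueeze(1)).sum().item()


def zero_shot_choice(model, tokenizer, prompt, options):
    """Pick the answer option with the highest total log-likelihood."""
    scores = [option_log_likelihood(model, tokenizer, prompt, opt) for opt in options]
    return max(range(len(options)), key=scores.__getitem__)
```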

(Figure 4 shows higher zero-shot task accuracy for LAWA checkpoints than for the corresponding standard training checkpoints.)

Conclusion

The paper concludes that early weight averaging via LAWA accelerates convergence and improves generalization in LLM pre-training without adding computational overhead. It also outlines directions for future work, including federated fine-tuning and continual training of intermediate checkpoints, pointing toward more efficient and scalable LLM training.
