Transformers Get Stable: An End-to-End Signal Propagation Theory for Language Models

(arXiv:2403.09635)
Published Mar 14, 2024 in cs.CL, cs.AI, cs.CV, and cs.LG

Abstract

In spite of their huge success, transformer models remain difficult to scale in depth. In this work, we develop a unified signal propagation theory and provide formulae that govern the moments of the forward and backward signal through the transformer model. Our framework can be used to understand and mitigate vanishing/exploding gradients, rank collapse, and instability associated with high attention scores. We also propose DeepScaleLM, an initialization and scaling scheme that conserves unit output/gradient moments throughout the model, enabling the training of very deep models with hundreds of layers. We find that transformer models could be much deeper: our deep models with fewer parameters outperform shallow models in Language Modeling, Speech Translation, and Image Classification, across Encoder-only, Decoder-only, and Encoder-Decoder variants, for both Pre-LN and Post-LN transformers, for multiple datasets and model sizes. These improvements also translate into improved performance on downstream Question Answering tasks and improved robustness for image classification.

Figure: DeepScaleLM conserves variance in both the forward and backward directions across 192 layers of a ViT trained on ImageNet.

Overview

  • The paper develops a theory of signal propagation in transformers to understand and mitigate instabilities in deep models.

  • It identifies three main sources of instability: vanishing/exploding gradients, rank collapse, and instability from high attention scores.

  • It introduces DeepScaleLM, an approach that stabilizes deep transformer models through a novel initialization scheme and scaling of residuals.

  • It empirically validates DeepScaleLM across multiple tasks, showing improved performance and stability in deeper models.

Understanding and Mitigating Instabilities in Deep Transformers

Signal Propagation through Transformers

The scaling of transformer models, especially in depth, has been a critical area of research due to its direct influence on a model's ability to learn complex patterns and generalize to unseen data. The challenge, however, lies in the instability issues that arise as models go deeper. This work develops a comprehensive theory of signal propagation in transformers, which sheds light on the underlying causes of these instabilities, and proposes a novel scheme, DeepScaleLM, to address them effectively.

Key Findings on Instability Issues

The analysis reveals three main sources of instability in deep transformer models:

  1. Vanishing/Exploding Gradients: Gradients grow or shrink exponentially as they backpropagate through layers, making deep models difficult to train.
  2. Rank Collapse: The rank of token representations diminishes across layers, leading to a loss of information.
  3. Instability from High Attention Scores: Large QK (query-key) dot products can saturate the softmax and destabilize training dynamics (a toy illustration follows this list).
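
To make the third failure mode concrete, here is a small toy illustration (our own example, not code from the paper): as the scale of the query-key scores grows, the softmax saturates toward a one-hot distribution, and the gradient flowing back through it collapses.

```python
import torch

# Toy illustration of instability from high attention scores (our own
# example, not the paper's code). As the scale of QK scores grows, softmax
# saturates toward a one-hot distribution, and the gradient of the attention
# weights with respect to the scores collapses toward zero.
torch.manual_seed(0)
for scale in (1.0, 5.0, 20.0):
    scores = (scale * torch.randn(16)).requires_grad_()
    attn = torch.softmax(scores, dim=-1)
    attn.max().backward()  # gradient of the dominant attention weight
    print(f"score std {scale:4.1f}: max attn weight {attn.max():.3f}, "
          f"grad norm {scores.grad.norm():.2e}")
```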

The work systematically dissects these issues by providing a unified formulaic framework that describes the forward and backward signal propagation through different components of the transformer model. This framework is pivotal in understanding how various factors, such as initialization schemes and component-wise operations, influence model stability.
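
To give a rough flavor of what the framework quantifies, the sketch below (our own diagnostic with illustrative sizes, not the paper's code) empirically measures the forward activation variance and backward gradient variance at each layer of a plain, unscaled residual stack; the paper's propagation formulae predict these moments analytically for each transformer component.

```python
import torch
import torch.nn as nn

# Measure forward/backward moments through an unscaled residual stack
# (our own diagnostic; the paper derives these moments analytically).
torch.manual_seed(0)
d, n_layers = 256, 64
blocks = nn.ModuleList([nn.Linear(d, d) for _ in range(n_layers)])

h = torch.randn(1024, d)
acts = []
for block in blocks:
    h = h + torch.relu(block(h))  # residual connection, no scaling
    h.retain_grad()               # keep gradients of intermediate activations
    acts.append(h)

h.sum().backward()
for i in (0, n_layers // 2, n_layers - 1):
    print(f"layer {i:2d}: forward var {acts[i].var():.3g}, "
          f"backward var {acts[i].grad.var():.3g}")
```

Without any corrective scaling, the forward variance grows with depth; DeepScaleLM's formulae are designed to keep both printed quantities near one at every layer.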

DeepScaleLM: Preserving Signal Integrity in Deeper Models

DeepScaleLM emerges as a solution, rooted in the insights from the theoretical analysis, for training very deep transformer models without succumbing to the aforementioned instabilities. At its core are a novel initialization scheme and careful scaling of the residual connections, which together ensure that both the forward and backward signals retain their integrity across layers. The scheme can be summarized as follows:

  • Scale the residual (skip) and block-output branches so that their combination preserves unit signal variance.
  • Adopt a layer-dependent output scaling that keeps the signal's variance at unity throughout the model.
  • Initialize weights with variances tailored to the model's depth, mitigating vanishing or exploding effects (a hedged sketch follows this list).
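
A minimal sketch of the idea, assuming a lambda/beta mixing of the skip and block branches with lambda^2 + beta^2 = 1 and a depth-dependent beta (the 1/(2N) schedule and the plain linear block below are illustrative assumptions; the paper's exact constants and per-component initializations differ):

```python
import math
import torch
import torch.nn as nn

class ScaledResidualBlock(nn.Module):
    """Variance-preserving residual block (illustrative, not the paper's exact scheme)."""

    def __init__(self, d: int, n_layers: int):
        super().__init__()
        # Depth-dependent residual scaling: beta shrinks with depth, and
        # lambda^2 + beta^2 = 1 keeps the combined output at unit variance.
        # The 1/(2 * n_layers) schedule here is an illustrative assumption.
        self.beta = math.sqrt(1.0 / (2 * n_layers))
        self.lam = math.sqrt(1.0 - self.beta ** 2)
        self.block = nn.Linear(d, d)
        # Initialize so the block output itself has ~unit variance
        # for unit-variance input.
        nn.init.normal_(self.block.weight, std=1.0 / math.sqrt(d))
        nn.init.zeros_(self.block.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # If skip and block output are roughly uncorrelated and each has unit
        # variance, Var(out) = lam^2 + beta^2 = 1 at every layer.
        return self.lam * x + self.beta * self.block(x)

torch.manual_seed(0)
depth = 192
x = torch.randn(1024, 256)
for blk in [ScaledResidualBlock(256, n_layers=depth) for _ in range(depth)]:
    x = blk(x)
print(f"variance after {depth} layers: {x.var():.2f}")  # stays near 1
```

Contrast this with the unscaled stack sketched earlier, whose forward variance grows with depth: here the same measurement stays near one regardless of the number of layers.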

Empirical Validation and Future Prospects

The approach is rigorously validated across various tasks, modalities, and architectures, demonstrating its efficacy in stabilizing the training of deep transformer models. Notably, deep models trained under the DeepScaleLM scheme outperform their shallower counterparts, achieving higher accuracy with fewer parameters in tasks ranging from language modeling and speech translation to image classification.

Looking ahead, the potential of DeepScaleLM extends beyond just stabilizing deep transformers. It opens up avenues for exploring even deeper architectures, potentially unlocking new levels of performance across tasks and domains. Additionally, the theoretical framework provides a foundation for future research to further dissect and enhance our understanding of transformer dynamics, paving the way for more robust and efficient models.

In conclusion, the work provides valuable insights into the challenges of scaling transformers and introduces a practical solution to navigate these challenges effectively. As the quest for more powerful AI models continues, approaches like DeepScaleLM will be crucial in harnessing the full potential of deep learning architectures.
