
Progress measures for grokking via mechanistic interpretability

(2301.05217)
Published Jan 12, 2023 in cs.LG and cs.AI

Abstract

Neural networks often exhibit emergent behavior, where qualitatively new capabilities arise from scaling up the number of parameters, training data, or training steps. One approach to understanding emergence is to find continuous progress measures that underlie the seemingly discontinuous qualitative changes. We argue that progress measures can be found via mechanistic interpretability: reverse-engineering learned behaviors into their individual components. As a case study, we investigate the recently discovered phenomenon of "grokking" exhibited by small transformers trained on modular addition tasks. We fully reverse-engineer the algorithm learned by these networks, which uses discrete Fourier transforms and trigonometric identities to convert addition to rotation about a circle. We confirm the algorithm by analyzing the activations and weights and by performing ablations in Fourier space. Based on this understanding, we define progress measures that allow us to study the dynamics of training and split training into three continuous phases: memorization, circuit formation, and cleanup. Our results show that grokking, rather than being a sudden shift, arises from the gradual amplification of structured mechanisms encoded in the weights, followed by the later removal of memorizing components.

Figure: Training and test losses on the modular addition task for varying fractions of training data.

Overview

  • The paper uses mechanistic interpretability to understand emergent behaviors in neural networks, focusing on 'grokking' in small transformers trained on modular addition tasks.

  • The authors divide the training process into three phases: memorization, circuit formation, and cleanup, emphasizing the role of weight decay in promoting generalized solutions.

  • The study defines two novel progress measures, restricted loss and excluded loss, to track the model's evolution towards generalization and explores the implications for future research on larger models.

Mechanistic Interpretability and Emergence in Neural Networks

Neural networks often exhibit emergent behaviors where qualitatively new capabilities arise as a result of scaling parameters, training data, or training steps. This paper presents an approach to understanding such emergent behaviors through mechanistic interpretability, focusing on the phenomenon of "grokking" observed in small transformers trained on modular addition tasks. The authors provide a comprehensive reverse engineering of the learned algorithm, confirming it via analysis of activations, weights, and Fourier space ablations.

The study investigates the training dynamics and divides them into three continuous phases: memorization, circuit formation, and cleanup. The authors argue that grokking results from the gradual amplification of structured mechanisms encoded in the weights, followed by the removal of the memorizing components.

Detailed Analysis of Grokking

Grokking is the abrupt transition to a generalizing solution that occurs long after a model has already overfit its training data. The authors examine this phenomenon on a modular addition task: given inputs \(a, b \in \{0, \ldots, P-1\}\) for a prime \(P\) (the paper uses \(P = 113\)), predict \(c = (a + b) \bmod P\). Small transformers trained with weight decay on this task exhibit grokking consistently. Through mechanistic interpretability, the authors reverse-engineer the learned algorithm and establish that these networks perform addition by converting it into rotation about a circle, leveraging discrete Fourier transforms and trigonometric identities.
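To make the reverse-engineered algorithm concrete, the following is a minimal NumPy sketch of the computation the paper describes (not of the transformer itself): each input is mapped onto a circle at a few key frequencies, angle-sum identities combine the two inputs into \(\cos(w_k(a+b))\) and \(\sin(w_k(a+b))\), and the logit for each candidate \(c\) is \(\cos(w_k(a+b-c))\), which is maximal exactly at \(c = (a+b) \bmod P\). The specific frequencies below are illustrative stand-ins; each trained network converges on its own small set.

```python
import numpy as np

P = 113                       # modulus used in the paper
key_freqs = [14, 35, 41, 52]  # illustrative; a trained run selects its own handful

def fourier_addition(a, b, P=P, freqs=key_freqs):
    """Sketch of the reverse-engineered algorithm: addition as rotation about a circle."""
    c = np.arange(P)
    logits = np.zeros(P)
    for k in freqs:
        w = 2 * np.pi * k / P
        # Angle-sum identities give cos(w(a+b)) and sin(w(a+b)) from per-input terms.
        cos_ab = np.cos(w * a) * np.cos(w * b) - np.sin(w * a) * np.sin(w * b)
        sin_ab = np.sin(w * a) * np.cos(w * b) + np.cos(w * a) * np.sin(w * b)
        # The logit for candidate c is cos(w(a+b-c)), maximal at c = (a+b) mod P.
        logits += cos_ab * np.cos(w * c) + sin_ab * np.sin(w * c)
    return int(np.argmax(logits))

assert fourier_addition(57, 83) == (57 + 83) % P
```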

The principal findings rest on four lines of evidence:

  1. Consistent Periodic Structures in Weights and Activations: The weights and activations exhibit a periodic structure, with the embedding matrix \(W_E\) being sparse in the Fourier basis and concentrated on a few key frequencies \(w_k\) (see the spectrum sketch after this list).
  2. Mechanistic Evidence: The neuron-logit map \(W_L\) is well approximated by a combination of sine and cosine terms at the key frequencies, confirming that the model uses the trigonometric identities.
  3. Approximation of Neuron Activations: Most neurons in the MLP (multi-layer perceptron) layer are well approximated by degree-2 polynomials of sines and cosines of the key frequencies.
  4. Faithful Component Ablations: Replacing components of the model with their approximations generally does not harm performance, and sometimes even improves it, validating the accuracy of the mechanistic model.
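As a concrete illustration of the first line of evidence, the sketch below (assuming access to the embedding matrix \(W_E\) of a trained run, which is not constructed here) computes the norm of the embedding along each Fourier frequency over the token dimension; a spectrum dominated by a handful of frequencies is what "sparse in the Fourier basis" means here.

```python
import numpy as np

def embedding_fourier_norms(W_E):
    """Norm of the token embedding along each Fourier frequency.

    W_E: (P, d_model) embedding matrix from a trained modular-addition
    transformer (assumed available; not constructed here).
    """
    coeffs = np.fft.rfft(W_E, axis=0)        # DFT over the P input tokens
    return np.linalg.norm(coeffs, axis=1)    # one norm per frequency 0..P//2

# The key frequencies show up as the dominant entries of this spectrum, e.g.:
# key_freqs = np.argsort(embedding_fourier_norms(W_E))[-4:]
```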

Progress Measures for Grokking

The authors use their mechanistic understanding to define two progress measures, restricted loss and excluded loss, that track the model's evolution towards the generalizing solution. Both measures improve continuously well before the apparent grokking transition, making the underlying training dynamics legible.

  1. Restricted Loss: Performance when all but the key frequencies are ablated, isolating the generalizing circuit.
  2. Excluded Loss: Performance on the training data when only the key frequencies are ablated, which separates memorization from generalization (a sketch of both Fourier-space ablations follows this list).
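The following is a hedged sketch of the shared Fourier-space ablation behind both measures, assuming per-example logits (or a single component's contribution to them) from a trained run are available; the paper applies the ablation to specific components of the network rather than to raw logits.

```python
import numpy as np

def fourier_ablate(logits, key_freqs, keep_key=True):
    """Filter a (num_examples, P) logit array in the Fourier basis over the
    output-class dimension.

    keep_key=True  keeps only the key frequencies  -> used for restricted loss
    keep_key=False removes the key frequencies     -> used for excluded loss
    """
    P = logits.shape[-1]
    coeffs = np.fft.rfft(logits, axis=-1)
    mask = np.zeros(coeffs.shape[-1], dtype=bool)
    mask[list(key_freqs)] = True
    if not keep_key:
        mask = ~mask
    return np.fft.irfft(np.where(mask, coeffs, 0), n=P, axis=-1)

# restricted loss: cross-entropy of fourier_ablate(logits, key_freqs, keep_key=True)
# excluded loss:   cross-entropy on the training set of fourier_ablate(..., keep_key=False)
```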

Phases of Training

The training process is divided into three distinct phases:

  1. Memorization Phase: The network memorizes training data without leveraging the key frequencies.
  2. Circuit Formation Phase: The network starts forming the Fourier multiplication circuit, aided by weight decay, showing continuous improvement in restricted loss.
  3. Cleanup Phase: Weight decay removes the remaining non-key-frequency components, transitioning the network to a simplified form that generalizes well (see the sketch after this list).
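A minimal sketch of the kind of training setup involved, assuming PyTorch and a placeholder model: the paper trains a small one-layer transformer with AdamW, and the decoupled weight-decay term is what drives the cleanup phase. The linear stand-in model and the specific hyperparameters here are purely illustrative.

```python
import torch
import torch.nn as nn

P = 113
model = nn.Linear(2 * P, P)  # placeholder, not the paper's architecture
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)

# One illustrative training step on random modular-addition examples.
a = torch.randint(0, P, (256,))
b = torch.randint(0, P, (256,))
x = torch.cat([nn.functional.one_hot(a, P), nn.functional.one_hot(b, P)], dim=-1).float()
loss = nn.functional.cross_entropy(model(x), (a + b) % P)

opt.zero_grad()
loss.backward()
opt.step()  # decoupled weight decay steadily shrinks weights the loss does not need
```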

Implications and Future Work

The paper's findings have significant practical and theoretical implications. They not only elucidate how emergent behaviors and grokking manifest at a mechanistic level but also highlight the critical role of weight decay in promoting generalized solutions. For future work, the authors suggest scaling mechanistic interpretability to larger, more complex models, and defining task-independent progress measures. They also advocate for developing a theory to predict the timing of phase transitions in emergent behaviors.

Conclusion

In summary, this paper successfully demonstrates the use of mechanistic interpretability to uncover the underlying dynamics of emergent behavior in neural networks. Through a detailed case study on small transformers trained for modular addition tasks, the authors provide clear evidence of the structured mechanisms that lead to grokking. This approach offers a promising direction for understanding and potentially predicting emergent behaviors in more complex and larger-scale models.
