
Mechanics of Next Token Prediction with Self-Attention

(2403.08081)
Published Mar 12, 2024 in cs.LG, cs.AI, cs.CL, and math.OC

Abstract

Transformer-based language models are trained on large datasets to predict the next token given an input sequence. Despite this simple training objective, they have led to revolutionary advances in natural language processing. Underlying this success is the self-attention mechanism. In this work, we ask: $\textit{What does a single self-attention layer learn from next-token prediction?}$ We show that training self-attention with gradient descent learns an automaton which generates the next token in two distinct steps: $\textbf{(1) Hard retrieval:}$ Given an input sequence, self-attention precisely selects the $\textit{high-priority input tokens}$ associated with the last input token. $\textbf{(2) Soft composition:}$ It then creates a convex combination of the high-priority tokens from which the next token can be sampled. Under suitable conditions, we rigorously characterize these mechanics through a directed graph over tokens extracted from the training data. We prove that gradient descent implicitly discovers the strongly-connected components (SCC) of this graph and self-attention learns to retrieve the tokens that belong to the highest-priority SCC available in the context window. Our theory relies on decomposing the model weights into a directional component and a finite component that correspond to hard retrieval and soft composition steps respectively. This also formalizes a related implicit bias formula conjectured in [Tarzanagh et al. 2023]. We hope that these findings shed light on how self-attention processes sequential data and pave the path toward demystifying more complex architectures.

Next-token prediction via hard retrieval of high-priority tokens using a 1-layer self-attention model.

Overview

  • The paper investigates the training mechanism of a single-layer self-attention model for next-token prediction, emphasizing a dual-phase process of hard retrieval and soft composition.

  • It introduces token-priority graphs and their strongly connected components (SCCs), showing that gradient descent implicitly discovers these components and uses them to prioritize token selection during training.

  • Theoretical and experimental results demonstrate global convergence of gradient descent to a support vector machine (SVM) solution, and establish that the associated SVM problem is feasible whenever the vocabulary size does not exceed the embedding dimension (with experiments suggesting alignment beyond this regime).

Mechanics of Next Token Prediction with Self-Attention

The study, "Mechanics of Next Token Prediction with Self-Attention," investigates the inner workings of a single-layer self-attention model when tasked with next-token prediction using gradient descent optimization. The analysis elucidates how this fundamental building block of Transformer-based language models acquires the ability to generate the next token effectively.

Core Findings

The authors introduce a dual-phase mechanism inherent in self-attention training (a minimal numerical sketch follows the list):

  1. Hard Retrieval: This involves the precise selection of high-priority tokens related to the last input token.
  2. Soft Composition: This phase constructs a convex combination of the retrieved tokens to sample the next token.
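
The setup below is hypothetical (a four-token vocabulary, one-hot embeddings, and a hand-chosen combined query-key matrix `W` with large entries) and is meant only to show how a saturated softmax carries out the retrieval while the output remains a convex combination of the retrieved embeddings; it is not the authors' code.

```python
import numpy as np

# Hypothetical setup: K-token vocabulary, one-hot embeddings, and a hand-chosen
# combined query-key matrix W whose large entries mark "high-priority" tokens
# relative to the last input token.
K = 4
E = np.eye(K)                      # token embeddings (one-hot for simplicity)
seq = [0, 2, 1, 3]                 # input token ids; the last token (3) issues the query

W = np.zeros((K, K))
W[1, 3] = W[2, 3] = 8.0            # tokens 1 and 2 are high priority given last token 3
W[0, 3] = -8.0                     # token 0 is low priority given last token 3

X = E[seq]                         # (T, d) stacked sequence embeddings
scores = X @ W @ E[seq[-1]]        # attention score of each position against the last token
attn = np.exp(scores - scores.max())
attn /= attn.sum()                 # softmax over positions

# Step 1 (hard retrieval): with a large-norm W the softmax saturates and puts
# essentially all of its mass on the high-priority context tokens.
print(np.round(attn, 3))           # ~[0, 0.5, 0.5, 0]

# Step 2 (soft composition): the output is a convex combination of the retrieved
# token embeddings, from which the next token can then be decoded or sampled.
output = attn @ X
print(np.round(output, 3))         # ~0.5 * e_1 + 0.5 * e_2
```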

The researchers introduce and formalize the concept of strongly connected components (SCCs) within directed token-priority graphs (TPGs). TPGs encode the priority relationships among tokens that are implicit in the training sequences. The gradient descent process implicitly identifies the SCCs of these graphs, thereby guiding self-attention to prioritize tokens from the highest-priority SCC present in the context window.
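
A rough sketch of how such a graph might be assembled from training examples is shown below. The edge convention used here (one graph per last token, with an edge from each example's label token to the other tokens in its context) is an assumed reading of the construction, chosen purely for illustration; `networkx` is used only to compute SCCs.

```python
import networkx as nx

# Hypothetical training data: each example is (context token ids, next-token label).
# The final entry of the context acts as the "last token" that issues the query.
data = [
    ([2, 0, 3], 1),   # last token 3, label 1
    ([1, 0, 3], 2),   # last token 3, label 2
    ([2, 1, 3], 1),   # last token 3, label 1
]

# Build one token-priority graph per last token: an edge label -> c records that
# the label token was preferred over context token c in some training example.
graphs = {}
for context, label in data:
    G = graphs.setdefault(context[-1], nx.DiGraph())
    for c in context:
        if c != label:
            G.add_edge(label, c)

# Strongly connected components group tokens with no strict priority between them
# (here tokens 1 and 2 form a cycle, so they share an SCC); the condensation of
# the graph orders the SCCs from highest to lowest priority.
for last, G in graphs.items():
    sccs = list(nx.strongly_connected_components(G))
    print(f"last token {last}: SCCs = {sccs}")
```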

Mathematical Formulation

The problem formulation hinges on analyzing a single-layer self-attention model trained with gradient descent for next-token prediction. The salient steps are as follows (a sketch of the resulting objective is given after the list):

  • Input Representation: The input sequence and the corresponding token embeddings are represented in matrix form.
  • Empirical Risk Minimization (ERM): The optimization problem aims to minimize the loss defined as the negative log-likelihood of correctly predicting the next token.
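
One representative way to write this objective (a sketch in the summary's notation rather than a verbatim reproduction of the paper's formulation) treats $W$ as a combined query-key matrix, $\mathbb{S}(\cdot)$ as the softmax, $x_L$ as the embedding of the last input token, and a fixed linear head $C$ for decoding:

$$
f_W(X) = X^\top \mathbb{S}(X W x_L), \qquad
\mathcal{L}(W) = -\frac{1}{n} \sum_{t=1}^{n} \log\Big[\mathbb{S}\big(C\, f_W(X_t)\big)_{y_t}\Big],
$$

where $X_t \in \mathbb{R}^{L \times d}$ stacks the token embeddings of the $t$-th training sequence and $y_t$ is its next-token label.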

The solution is framed as finding weights $W$ whose induced token correlations respect the constraints imposed by the TPGs (a sketch of the resulting max-margin problem is given after the list):

  • the priority correlation is strictly increased for pairs $(i \Rightarrow j)$, i.e., where token $i$ has higher priority than token $j$;
  • the correlation is neutralized for pairs $(i \asymp j)$, i.e., tokens of equal priority (belonging to the same SCC).
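
In the limit of training, this constraint set can be summarized by a graph-based max-margin problem. The form below is a sketch consistent with the notation above, with $e_i$ the embedding of token $i$ and $k$ ranging over last tokens (so the constraints are read off the TPG attached to each $k$); the paper's exact formulation may differ in details:

$$
\min_{W}\ \|W\|_F \quad \text{s.t.} \quad
(e_i - e_j)^\top W e_k \ge 1 \ \ \text{for } (i \Rightarrow j), \qquad
(e_i - e_j)^\top W e_k = 0 \ \ \text{for } (i \asymp j).
$$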

The researchers derive the gradient-descent weight updates and analyze their convergence in this setting, decomposing the learned weights into a finite component and a diverging directional component.
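
The decomposition referenced in the abstract can be written schematically as follows (the growth rate of the diverging coefficient is left unspecified here):

$$
W(t) \;\approx\; W^{\mathrm{fin}} + \rho(t)\, W^{\mathrm{svm}}, \qquad \rho(t) \to \infty,
$$

where the diverging directional part $W^{\mathrm{svm}}$ saturates the softmax and produces the hard-retrieval behavior, while the finite part $W^{\mathrm{fin}}$ shapes the soft composition among the retrieved tokens.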

Key Results

  • Global Convergence: For the log-loss and under certain assumptions, they prove that gradient descent converges globally, with the attention weights growing in norm while their direction converges to the SVM solution.
  • Feasibility of SVM: They establish that, provided the embedding matrix is full rank (i.e., the token embeddings are linearly independent), the SVM problem associated with the token-priority graphs is feasible. This is significant because it covers any setting where the vocabulary size $K$ does not exceed the embedding dimension $d$.

The theoretical results are validated with experimental findings, showing that the gradient descent process indeed leads the attention weights to align with the SVM formulation, even for larger vocabulary scenarios where $K > d$.
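
This behavior can be reproduced in a self-contained toy run. The sketch below is hypothetical (one-hot embeddings, the embedding matrix reused as the decoding head, a single training sequence) and is not the authors' experimental code; it trains $W$ by plain gradient descent on the log-loss and prints how the Frobenius norm keeps growing while the last token's attention concentrates on the higher-priority context token.

```python
import numpy as np

# Hypothetical toy run: one training sequence with context [0, 1, 3] (last token 3)
# and label 1, so token 1 outranks token 0 relative to last token 3.
K, d = 4, 4
E = np.eye(K)                           # one-hot embeddings (full rank, K <= d)
seqs = [([0, 1, 3], 1)]

W = np.zeros((d, d))
lr = 0.5
for step in range(1, 2001):
    grad = np.zeros_like(W)
    for context, y in seqs:
        X = E[context]                  # (L, d) context embeddings
        q = E[context[-1]]              # query = last token embedding
        s = X @ W @ q                   # attention scores per position
        a = np.exp(s - s.max()); a /= a.sum()
        r = a @ X                       # soft composition of context embeddings
        logits = E @ r                  # decode with the embedding matrix as head
        p = np.exp(logits - logits.max()); p /= p.sum()
        # Backpropagate -log p[y] through the head, the composition, and the softmax.
        dr = E.T @ (p - np.eye(K)[y])   # dL/dr
        da = X @ dr                     # dL/da
        ds = a * (da - a @ da)          # softmax Jacobian applied to da
        grad += np.outer(X.T @ ds, q)   # dL/dW = sum_t ds_t * x_t q^T
    W -= lr * grad
    if step % 500 == 0:
        s = E[seqs[0][0]] @ W @ E[seqs[0][0][-1]]
        a = np.exp(s - s.max()); a /= a.sum()
        print(step, "||W||_F =", round(float(np.linalg.norm(W)), 2),
              "attention =", np.round(a, 3))
```

In this toy setting the loss has no finite minimizer, so the norm of $W$ keeps increasing (slowly) while the last token's attention distribution converges to a point mass on the high-priority token, mirroring the hard-retrieval behavior described above.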

Implications and Future Work

This study sheds light on the implicit biases of the self-attention mechanism in Transformer-based models, particularly in the context of next-token prediction. The authors' insights into the SCC-based token-priority mechanism pave the way for more precise characterization and potentially even the design of more efficient attention mechanisms.

Future research can expand on several fronts:

  • Multi-layer and Multi-head Extensions: Extending the analysis to multi-layer, multi-head architectures typical of full-fledged Transformers.
  • Impact of MLP Layers: Investigating how the feed-forward layers following self-attention in Transformers contribute to the observed token selection and composition mechanisms.
  • Relaxation of Assumptions: Analyzing more complex settings without the convexity assumptions to generalize the convergence results further.

By distilling the mechanics of next-token prediction at the level of a single-layer self-attention model, this paper lays a robust theoretical foundation from which to explore and optimize more complex architectures in natural language processing and beyond.
