
Transformer Feed-Forward Layers Are Key-Value Memories

(arXiv:2012.14913)
Published Dec 29, 2020 in cs.CL

Abstract

Feed-forward layers constitute two-thirds of a transformer model's parameters, yet their role in the network remains under-explored. We show that feed-forward layers in transformer-based language models operate as key-value memories, where each key correlates with textual patterns in the training examples, and each value induces a distribution over the output vocabulary. Our experiments show that the learned patterns are human-interpretable, and that lower layers tend to capture shallow patterns, while upper layers learn more semantic ones. The values complement the keys' input patterns by inducing output distributions that concentrate probability mass on tokens likely to appear immediately after each pattern, particularly in the upper layers. Finally, we demonstrate that the output of a feed-forward layer is a composition of its memories, which is subsequently refined throughout the model's layers via residual connections to produce the final output distribution.

Figure: a feed-forward layer emulating a key-value memory, with input vectors producing memory coefficients over value vectors.

Overview

  • Feed-forward layers in transformer models act as key-value memories, detecting input patterns that shape the output distribution.

  • These layers exhibit a stratification of pattern complexity, with lower layers recognizing shallow patterns and upper layers recognizing semantic ones.

  • The study demonstrates how values in these layers induce distributions over the output vocabulary, complementing detected input patterns.

  • The final model predictions result from an aggregation of memory contributions across all layers, refined through residual connections.

Unveiling the Function of Feed-Forward Layers in Transformer Models

Overview

In transformer models, feed-forward layers account for a significant share of the parameters, yet their specific role has not been thoroughly explored. This paper shows that feed-forward layers behave like key-value memory systems, essentially acting as pattern detectors that shape the model's output distribution. The analysis reveals that different layers capture input patterns of varying complexity, from shallow to semantic, and that these patterns are combined to form the final prediction.

Feed-Forward Layers as Neural Memory

Feed-forward layers in transformers have been under-explored despite their predominance in the model's architecture. This paper argues that these layers function as key-value memories: the first parameter matrix holds keys that detect specific patterns in the input text, while the second holds values that induce distributions over the model's output vocabulary. The two computations are nearly identical in form, differing only in the activation function (unnormalized ReLU for feed-forward layers versus softmax for neural memory), which underscores the pattern-recognition role feed-forward layers play across the input data.
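The correspondence is easy to see in code. Below is a minimal sketch in PyTorch, with hypothetical toy dimensions, contrasting the standard feed-forward computation FF(x) = f(x·K^T)·V with a classic neural key-value memory; the only difference is the choice of f.

```python
import torch
import torch.nn.functional as F

def feed_forward(x, K, V):
    """Transformer feed-forward layer (bias terms omitted, as in the paper):
    FF(x) = f(x @ K.T) @ V, with f = ReLU."""
    m = F.relu(x @ K.T)   # memory coefficients: how strongly each key fires
    return m @ V          # output = weighted sum of value vectors

def key_value_memory(x, K, V):
    """Classic neural key-value memory: identical computation except the
    coefficients are normalized with softmax instead of passed through ReLU."""
    m = F.softmax(x @ K.T, dim=-1)
    return m @ V

# Toy dimensions (hypothetical): model dim d = 8, d_ff = 32 memory slots.
d, d_ff = 8, 32
K = torch.randn(d_ff, d)   # rows are keys   k_i
V = torch.randn(d_ff, d)   # rows are values v_i
x = torch.randn(d)
print(feed_forward(x, K, V).shape)      # torch.Size([8])
print(key_value_memory(x, K, V).shape)  # torch.Size([8])
```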

Patterns Captured by Keys

An experimental investigation into what these keys represent reveals that each key is associated with distinct, human-interpretable patterns within the input texts. Lower layers tend to recognize shallow patterns, such as specific n-grams, while upper layers are adept at identifying more semantic patterns, indicating a stratification in the complexity of recognized patterns across the model. This stratification supports the concept of hierarchical processing in neural networks, where initial layers focus on low-level features and higher layers on more abstract concepts.
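The paper identifies these patterns by retrieving, for each key, the training prefixes that activate it most strongly. A minimal sketch of that retrieval, assuming the hidden states entering the feed-forward layer have already been collected from a corpus into `prefix_reprs` (a hypothetical precomputed tensor):

```python
import torch

def top_trigger_examples(key, prefix_reprs, prefixes, top_k=25):
    """Rank training prefixes by the memory coefficient ReLU(x @ k_i) and
    return the strongest triggers for one key. `prefix_reprs` is a
    (num_prefixes, d) tensor of hidden states entering the feed-forward
    layer; `prefixes` holds the corresponding text. Both are assumptions
    of this sketch, precomputed elsewhere."""
    scores = torch.relu(prefix_reprs @ key)        # (num_prefixes,)
    vals, idx = torch.topk(scores, k=top_k)
    return [(prefixes[i], v.item()) for i, v in zip(idx.tolist(), vals)]
```

Human annotators can then inspect the returned prefixes to characterize the pattern a key responds to, such as a shared n-gram in lower layers or a shared topic in upper layers.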

Values as Output Distributions

Turning to the values, the study shows that each value can be viewed as inducing a distribution over the output vocabulary, one that complements the input pattern detected by its corresponding key, particularly in the model's upper layers. This key-value agreement grows more pronounced in higher layers, suggesting that as the model processes information hierarchically, the upper layers synthesize detected patterns into more accurate next-token predictions.
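Concretely, a value vector can be read as a vocabulary distribution by projecting it through the model's output embedding matrix, which is how the paper compares a value's top-ranked token with the token that actually follows its key's trigger examples. A hedged sketch, where `E` stands in for the output embedding matrix:

```python
import torch

def value_as_distribution(v, E):
    """Read a value vector v_i as a distribution over the vocabulary by
    projecting it through the output embedding matrix E (vocab_size x d):
    p_i = softmax(E @ v_i)."""
    return torch.softmax(E @ v, dim=-1)

# Hypothetical usage: V[i] is the i-th value vector of some layer.
# p = value_as_distribution(V[i], E)
# top_token_id = p.argmax().item()  # compare with the trigger's next token
```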

Memory Aggregation and Model Prediction

The paper explores how the transformer model leverages these individual memories across all layers to refine and derive the final output distribution. It demonstrates that model predictions are not solely reliant on dominant memory activations but result from the complex aggregation of multiple memory contributions, refined through residual connections across layers. This process illustrates a bottom-up assembly where detected patterns are incrementally merged and refined to form the model’s output.
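A simplified sketch of this aggregation for a single layer, with attention and layer normalization omitted for clarity: the feed-forward output is a weighted composition of value vectors, and the residual connection merges it back into the running representation rather than replacing it.

```python
import torch
import torch.nn.functional as F

def ffn_with_residual(x, K, V):
    """One layer's contribution as a composition of memories merged into
    the residual stream: y = x + sum_i m_i * v_i."""
    m = F.relu(x @ K.T)    # m_i: activation of each memory cell
    ffn_out = m @ V        # composition: weighted sum of value vectors
    return x + ffn_out     # residual connection refines, not replaces

# Stacking such layers and projecting the final residual stream through
# the output embedding yields the model's final token distribution.
```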

Implications and Future Directions

The outlined findings bear significant implications for both theoretical understanding and practical application of transformer models. They provide a more nuanced comprehension of how feed-forward layers contribute to the model's ability to process and predict linguistic patterns. Practically, this insight could inspire more efficient model architectures and interpretability tools by focusing on the nuanced roles of these layers. Moreover, exploring how these findings translate beyond language models to other transformer-based applications represents an intriguing future direction.

Conclusion

In summary, by casting feed-forward layers as memory systems that recognize patterns and influence output distributions, this paper sheds light on the critical yet under-examined function of these layers within transformer models. This understanding not only enriches the current comprehension of transformer architectures but also opens new avenues for research in making these models more interpretable, efficient, and versatile.
