Emergent Mind

Teaching Transformers Causal Reasoning through Axiomatic Training

(arXiv:2407.07612)
Published Jul 10, 2024 in cs.LG, cs.AI, and cs.CL

Abstract

For text-based AI systems to interact in the real world, causal reasoning is an essential skill. Since interventional data is costly to generate, we study to what extent an agent can learn causal reasoning from passive data. Specifically, we consider an axiomatic training setup where an agent learns from multiple demonstrations of a causal axiom (or rule), rather than incorporating the axiom as an inductive bias or inferring it from data values. A key question is whether the agent would learn to generalize from the axiom demonstrations to new scenarios. For example, if a transformer model is trained on demonstrations of the causal transitivity axiom over small graphs, would it generalize to applying the transitivity axiom over large graphs? Our results, based on a novel axiomatic training scheme, indicate that such generalization is possible. We consider the task of inferring whether a variable causes another variable, given a causal graph structure. We find that a 67 million parameter transformer model, when trained on linear causal chains (along with some noisy variations), can generalize well to new kinds of graphs, including longer causal chains, causal chains with reversed order, and graphs with branching, even when it is not explicitly trained for such settings. Our model performs on par with (or even better than) many larger language models such as GPT-4, Gemini Pro, and Phi-3. Overall, our axiomatic training framework provides a new paradigm of learning causal reasoning from passive data that can be used to learn arbitrary axioms, as long as sufficient demonstrations can be generated.

Evaluating transformer generalization from simple causal chains to complex, branched, and shuffled structures.

Overview

  • The paper evaluates the ability of transformer models to learn causal reasoning from passive data using a novel axiomatic training scheme.

  • Key methodological contributions include generating training data via causal axioms, employing different positional encoding strategies, and testing on complex evaluation datasets.

  • Results show that the proposed training method enables transformers to generalize to longer, shuffled, reversed, and branched causal sequences, outperforming baselines such as GPT-4 in some cases.

An Analysis of Axiomatic Training for Causal Reasoning in Transformers

Introduction

Causal reasoning is a fundamental capability for AI systems to interact effectively in the real world. While interventional data is often costly to produce, passive data provides a less expensive alternative to train AI models for causal inference. The focus of the paper, "Teaching Transformers Causal Reasoning through Axiomatic Training", is to evaluate the extent to which an AI agent, specifically a transformer model, can learn causal reasoning skills from passive data. This is achieved through a novel axiomatic training scheme that teaches transformers causal axioms directly from symbolic demonstrations.

Methodology

The paper proposes an approach in which transformers are trained on symbolic tuples representing causal axioms. The main methodological contribution is a training framework in which each data instance comprises a premise, a hypothesis, and a result label (Yes or No). Crucially, the model learns causal reasoning principles directly from these demonstrative tuples, without requiring interventional data.
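One way to picture such a demonstrative tuple is as a single flat text string. The sketch below assumes an illustrative template (the phrasing "causes" / "Does ... cause ...?" is a guess at the surface form, not necessarily the paper's exact wording):

```python
# Hypothetical rendering of one axiomatic training instance: a premise
# listing direct causal edges, a hypothesis querying an implied edge,
# and a Yes/No label. The exact template in the paper may differ.
def format_instance(edges, query, label):
    """Render edges, a causal query, and a Yes/No label as one string."""
    premise = " ".join(f"{a} causes {b}." for a, b in edges)
    hypothesis = f"Does {query[0]} cause {query[1]}?"
    return f"{premise} {hypothesis} {label}"

# A transitivity demonstration over a two-edge chain X -> Y -> Z:
print(format_instance([("X", "Y"), ("Y", "Z")], ("X", "Z"), "Yes"))
```

Many such strings, over varied graphs and node names, form the passive training corpus.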

Key Components:

  1. Synthetic Data Generation:
    • The training data is generated using causal axioms such as the transitivity axiom. For example, if X -> Y and Y -> Z, then X -> Z.
    • Variability in training data is introduced by employing different node names, graph topologies, and causal graphs of varying lengths.
  2. Positional Encoding Strategies:
    • The paper evaluates three types of positional encodings: No positional encoding (NoPE), sinusoidal positional encoding (SPE), and learnable positional encoding (LPE).
  3. Evaluation Datasets:
    • Several complex evaluation datasets are designed to test different aspects of generalization such as longer graphs, shuffled sequences, reversed sequences, and branched networks.
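The synthetic-data step above can be sketched as follows. This is a minimal reconstruction under stated simplifications: node names are unique single letters (the paper varies name length and topology), and ground-truth labels come from directed reachability, i.e., repeated application of the transitivity axiom:

```python
import random
import string

def random_chain(length, rng):
    """Sample a linear causal chain A -> B -> ... with unique 1-letter names.
    (The paper also varies name length and graph topology; this sketch does not.)"""
    names = rng.sample(string.ascii_uppercase, length)
    return list(zip(names, names[1:]))

def reachable(edges, src, dst):
    """Ground-truth label: is dst reachable from src via repeated transitivity?"""
    adj = {}
    for a, b in edges:
        adj.setdefault(a, set()).add(b)
    stack, seen = [src], set()
    while stack:
        node = stack.pop()
        if node == dst:
            return True
        if node not in seen:
            seen.add(node)
            stack.extend(adj.get(node, ()))
    return False

def make_dataset(n, rng=None):
    """Generate n (premise, hypothesis, label) training tuples from random chains."""
    rng = rng or random.Random(0)
    data = []
    for _ in range(n):
        edges = random_chain(rng.randint(3, 6), rng)
        nodes = [edges[0][0]] + [b for _, b in edges]
        a, b = rng.sample(nodes, 2)
        premise = " ".join(f"{x} causes {y}." for x, y in edges)
        label = "Yes" if reachable(edges, a, b) else "No"
        data.append((premise, f"Does {a} cause {b}?", label))
    return data
```

Sampling the query pair in both directions yields both Yes and No labels, so the model cannot succeed by always answering one way.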

Results

Length Generalization

Transformers trained using the proposed axiomatic training approach showed impressive generalization capabilities to longer causal sequences that were not seen during training. Notably, the best results were achieved using models with NoPE, outperforming other baselines including larger models such as GPT-4.
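For context on what the SPE baseline adds and NoPE omits: sinusoidal positional encoding follows the standard fixed formulation from the original Transformer, sketched below in plain Python (returned as nested lists rather than tensors, for self-containment):

```python
import math

def sinusoidal_pe(seq_len, d_model):
    """Fixed sinusoidal positional encodings (the standard Transformer scheme):
    even dimensions get sin(pos / 10000^(i/d_model)), the following odd
    dimension gets the matching cos. Returns a seq_len x d_model table.
    A NoPE model simply skips adding this positional signal to embeddings."""
    pe = [[0.0] * d_model for _ in range(seq_len)]
    for pos in range(seq_len):
        for i in range(0, d_model, 2):
            angle = pos / (10000 ** (i / d_model))
            pe[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                pe[pos][i + 1] = math.cos(angle)
    return pe
```

Because these encodings are tied to absolute positions, sequences longer than those seen in training land on unfamiliar positional signals, which is one intuition for why the position-free NoPE variant generalizes better to longer chains.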

Node Name Shift

The models also performed robustly when tested on sequences with longer node names than those seen during training, indicating that the transformer successfully learned the underlying causal relationships rather than memorizing specific tokens.

Order of Causal Sequences

Performance on shuffled and fully reversed sequences further demonstrated the effectiveness of the axiomatic training approach. The NoPE models showcased a remarkable capacity to generalize to these new configurations, in some cases even surpassing large-scale language models like GPT-4.

Branching

The evaluation on branched causal graphs, which represent more complex structures, revealed that the axiomatic approach could handle significant complexity, maintaining relatively high accuracy even for unseen, densely branched networks.
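A plausible way to construct such branched evaluation graphs is to sample a random DAG (edges only point forward in a fixed node order, so branching and re-merging occur but cycles cannot) and label causal queries by directed reachability. This is an illustrative reconstruction, not the paper's exact generator:

```python
import random

def random_branched_dag(n_nodes, edge_prob, rng):
    """Sample a DAG over N0..N{n-1}: each forward pair (i, j) with i < j
    gets an edge with probability edge_prob, giving branched structures."""
    return [(f"N{i}", f"N{j}")
            for i in range(n_nodes)
            for j in range(i + 1, n_nodes)
            if rng.random() < edge_prob]

def causes(edges, src, dst):
    """Ground truth for 'does src cause dst?': directed reachability."""
    adj = {}
    for a, b in edges:
        adj.setdefault(a, set()).add(b)
    stack, seen = [src], set()
    while stack:
        node = stack.pop()
        if node == dst:
            return True
        if node not in seen:
            seen.add(node)
            stack.extend(adj.get(node, ()))
    return False
```

Raising `edge_prob` yields the denser branching that stresses the trained model hardest.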

Implications and Future Work

The axiomatic training framework introduced in this paper presents a new paradigm for teaching transformers causal reasoning. By learning from symbolic data, transformers can grasp causal axioms that allow them to generalize to diverse downstream applications.

Theoretical Implications

This work contributes to the broader literature on causal learning from passive data by demonstrating that transformers can learn complex causal reasoning abilities from structured, synthetic data representing causal axioms. This suggests that similar approaches could be employed to train AI models on various logical reasoning tasks, thereby improving their reasoning capabilities without extensive manual intervention.

Practical Implications

The performance of the trained transformers, especially models like TS2 (NoPE), showed promise in causal reasoning, rivaling and sometimes surpassing powerful LLMs like GPT-4 in specific contexts. This indicates that axiomatic training could be an efficient strategy for developing robust AI systems capable of sophisticated reasoning without the extensive computational resources typically required.

Future Work

  • Extending the axiomatic training approach to a broader set of causal axioms beyond transitivity, such as d-separation or the Markov property, could further enhance the reasoning capabilities of transformers.
  • Applying this training strategy to other logical and deductive reasoning tasks to explore its generalizability beyond causal inference.
  • Investigating the theoretical underpinnings of why certain positional encoding strategies, notably NoPE, significantly enhance the model's generalization capabilities.

Conclusion

The paper demonstrates that transformers can effectively learn causal reasoning through axiomatic training. The method not only lets transformers learn from passive data but also enables them to generalize to more complex causal structures, achieving accuracy comparable to or better than existing LLMs on specialized tasks. These results suggest a promising direction for developing more efficient AI systems capable of advanced reasoning, with wide-ranging applications in AI development and beyond.
