
Step-DPO: Step-wise Preference Optimization for Long-chain Reasoning of LLMs

(2406.18629)
Published Jun 26, 2024 in cs.LG, cs.AI, and cs.CL

Abstract

Mathematical reasoning presents a significant challenge for LLMs due to the extensive and precise chain of reasoning required for accuracy. Ensuring the correctness of each reasoning step is critical. To address this, we aim to enhance the robustness and factuality of LLMs by learning from human feedback. However, Direct Preference Optimization (DPO) has shown limited benefits for long-chain mathematical reasoning, as models employing DPO struggle to identify detailed errors in incorrect answers. This limitation stems from a lack of fine-grained process supervision. We propose a simple, effective, and data-efficient method called Step-DPO, which treats individual reasoning steps as units for preference optimization rather than evaluating answers holistically. Additionally, we have developed a data construction pipeline for Step-DPO, enabling the creation of a high-quality dataset containing 10K step-wise preference pairs. We also observe that in DPO, self-generated data is more effective than data generated by humans or GPT-4, due to the latter's out-of-distribution nature. Our findings demonstrate that as few as 10K preference data pairs and fewer than 500 Step-DPO training steps can yield a nearly 3% gain in accuracy on MATH for models with over 70B parameters. Notably, Step-DPO, when applied to Qwen2-72B-Instruct, achieves scores of 70.8% and 94.0% on the test sets of MATH and GSM8K, respectively, surpassing a series of closed-source models, including GPT-4-1106, Claude-3-Opus, and Gemini-1.5-Pro. Our code, data, and models are available at https://github.com/dvlab-research/Step-DPO.

Figure: Step-DPO's data construction process.

Overview

  • The paper introduces Step-DPO, a novel method that enhances the performance of LLMs in long-chain mathematical reasoning tasks by focusing on step-wise optimization rather than holistic answer preferences.

  • Step-DPO identifies and corrects errors in intermediate steps of the reasoning process to ensure higher overall accuracy, demonstrating substantial improvements in accuracy and efficiency over traditional Direct Preference Optimization (DPO) and other state-of-the-art models.

  • The method holds theoretical and practical significance, paving the way for further research into step-wise optimization and for applications in educational tools and automated reasoning systems, with the potential to improve AI-driven analytics and decision-making.

Step-DPO: Step-wise Preference Optimization for Long-chain Reasoning of LLMs

The paper "Step-DPO: Step-wise Preference Optimization for Long-chain Reasoning of LLMs" addresses a critical challenge in the field of LLMs: enhancing the robustness and factuality of LLMs during long-chain mathematical reasoning tasks. The authors—Xin Lai, Zhuotao Tian, Yukang Chen, Senqiao Yang, Xiangru Peng, and Jiaya Jia—propose a novel method called Step-DPO, which significantly outperforms traditional Direct Preference Optimization (DPO) techniques.

Background and Motivation

LLMs have exhibited considerable prowess in various natural language processing tasks; as autoregressive models, they predict each token conditioned on the preceding context. However, their performance in long-chain reasoning, especially in mathematical contexts, remains suboptimal. Mathematical reasoning requires precise, step-by-step processing, where a single intermediate error can derail the final outcome. Traditional DPO, which optimizes over preferences between complete answers, lacks the granularity needed to locate and rectify these intermediate errors.
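
For context, the standard DPO objective scores entire answers: given a prompt $x$ with a preferred answer $y_{\mathrm{win}}$ and a dispreferred answer $y_{\mathrm{lose}}$, it maximizes the margin between their policy-to-reference log-ratios. The formulation below is the commonly used one (not quoted from this paper); $\beta$ controls the strength of the implicit KL constraint and $\pi_{\mathrm{ref}}$ is the frozen reference model:

```latex
\mathcal{L}_{\mathrm{DPO}}(\theta) =
-\,\mathbb{E}_{(x,\, y_{\mathrm{win}},\, y_{\mathrm{lose}}) \sim \mathcal{D}}
\left[ \log \sigma\!\left(
\beta \log \frac{\pi_\theta(y_{\mathrm{win}} \mid x)}{\pi_{\mathrm{ref}}(y_{\mathrm{win}} \mid x)}
- \beta \log \frac{\pi_\theta(y_{\mathrm{lose}} \mid x)}{\pi_{\mathrm{ref}}(y_{\mathrm{lose}} \mid x)}
\right) \right]
```

Because the two answers are compared only as wholes, a long, mostly correct solution with one flawed step receives the same treatment as a solution that is wrong throughout, which is the granularity problem Step-DPO targets.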

Step-DPO Framework

In response to the limitations of DPO, the authors propose Step-DPO, which shifts the preference optimization focus from complete answers to individual reasoning steps. This approach involves the following steps:

  1. Step-wise Formulation: Instead of evaluating complete answers, Step-DPO parses the reasoning process into discrete steps and targets the first incorrect step, preferring a correct next step over the erroneous one given the same preceding context (the resulting step-level objective is sketched after this list). This fine-grained optimization ensures that errors are identified and corrected within the reasoning chain, thereby improving overall accuracy.
  2. Data Construction Pipeline: The authors detail a systematic pipeline to construct a high-quality dataset of roughly 10K step-wise preference pairs (a code sketch of the pipeline also follows this list). The pipeline involves:
  • Error Collection: Gathering a set of problems and the corresponding model-inferred answers.
  • Step Localization: Identifying the first erroneous step in the generated answers.
  • Rectification: Sampling the corrected next step from the model itself, so the preference data remains in-distribution and avoids the pitfalls the authors observe with answers written by humans or GPT-4.
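
Following the paper's description, the step-level objective can be sketched by conditioning both policies on the prompt $x$ together with the verified preceding steps $s_{1:k-1}$, and contrasting a correct next step $s_{\mathrm{win}}$ against the erroneous step $s_{\mathrm{lose}}$. The notation here is our shorthand rather than a quotation of the paper's exact formula:

```latex
\mathcal{L}_{\mathrm{Step\text{-}DPO}}(\theta) =
-\,\mathbb{E}_{(x,\, s_{1:k-1},\, s_{\mathrm{win}},\, s_{\mathrm{lose}}) \sim \mathcal{D}}
\left[ \log \sigma\!\left(
\beta \log \frac{\pi_\theta(s_{\mathrm{win}} \mid x,\, s_{1:k-1})}{\pi_{\mathrm{ref}}(s_{\mathrm{win}} \mid x,\, s_{1:k-1})}
- \beta \log \frac{\pi_\theta(s_{\mathrm{lose}} \mid x,\, s_{1:k-1})}{\pi_{\mathrm{ref}}(s_{\mathrm{lose}} \mid x,\, s_{1:k-1})}
\right) \right]
```

The preference signal is thus applied exactly where the reasoning first goes wrong, rather than being diluted across an entire multi-step solution.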
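
A minimal code sketch of the data construction flow is given below. It assumes caller-supplied helpers (`generate`, `split_steps`, `check_step`, `check_answer`) that are placeholders, not APIs from the paper's released code; it illustrates error collection, step localization, and rectification rather than the authors' exact implementation.

```python
# Hypothetical sketch of the Step-DPO data-construction pipeline described above.
# All helper callables are placeholders supplied by the caller.
from typing import Callable, List, Optional, Tuple


def build_step_preference_pair(
    problem: str,
    gold_answer: str,
    generate: Callable[[str, List[str]], str],          # samples a solution (or continuation) from the policy model
    split_steps: Callable[[str], List[str]],            # splits a solution into reasoning steps
    check_step: Callable[[str, List[str], str], bool],  # verifies one step given the problem and prior steps
    check_answer: Callable[[str, str], bool],           # compares a solution's final answer with the gold answer
    max_retries: int = 8,
) -> Optional[Tuple[str, List[str], str, str]]:
    """Return (problem, correct_prefix, chosen_step, rejected_step), or None if no pair is found."""
    # 1. Error collection: sample a full chain-of-thought answer from the model itself,
    #    keeping the preference data in-distribution.
    solution = generate(problem, [])
    if check_answer(solution, gold_answer):
        return None  # only incorrect solutions contribute preference pairs

    # 2. Step localization: find the first erroneous step.
    steps = split_steps(solution)
    k = next(
        (i for i, s in enumerate(steps) if not check_step(problem, steps[:i], s)),
        len(steps) - 1,  # fall back to the last step if the verifier flags nothing
    )
    prefix, rejected_step = steps[:k], steps[k]

    # 3. Rectification: resample continuations from the same model, conditioned on the
    #    verified prefix, until one reaches the correct final answer; its first new step
    #    becomes the preferred ("chosen") step.
    for _ in range(max_retries):
        continuation = generate(problem, prefix)
        if check_answer(continuation, gold_answer):
            chosen_step = split_steps(continuation)[0]
            return problem, prefix, chosen_step, rejected_step
    return None
```

Each resulting tuple supplies one step-wise preference pair: the shared context (problem plus verified prefix), the chosen step, and the rejected step used in the objective above.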

Experimental Results

The Step-DPO method was evaluated rigorously, with the results demonstrating substantial improvements over traditional DPO and existing state-of-the-art models. Notably:

  • Quantitative Gains: Step-DPO, when applied to Qwen2-72B-Instruct, achieved 70.8% and 94.0% accuracy on the MATH and GSM8K test sets, respectively. This surpasses prominent closed-source models, including GPT-4-1106, Claude-3-Opus, and Gemini-1.5-Pro.
  • Efficiency: Fewer than 500 Step-DPO training steps yielded a nearly 3% accuracy gain on MATH for models with over 70B parameters, underscoring the method's data and compute efficiency.
  • Robustness: The method proved effective even with a smaller dataset (10K preference pairs) and demonstrated strong generalization capabilities across different problem sets, including competition-level math problems.

Implications and Future Directions

The implications of this work are manifold:

  1. Theoretical Advancements: The Step-DPO method introduces a profound shift in the approach to preference optimization, emphasizing the necessity of fine-grained supervision in long-chain reasoning tasks. This paves the way for further research into step-wise optimizations in other domains beyond mathematics.
  2. Practical Applications: Practically, this method can enhance the reliability of LLMs deployed in educational tools, automated theorem proving, and other applications requiring robust sequential reasoning. By minimizing error propagation in complex reasoning tasks, Step-DPO holds potential for significantly improving AI-driven analytics and decision-making systems.
  3. Future Research: Future developments might focus on refining the data construction pipeline, perhaps incorporating more sophisticated error localization techniques or extending the approach to encompass other forms of long-chain reasoning tasks. Exploration into hybrid models that synergize Step-DPO with advanced RLHF techniques could yield further enhancements in LLM alignment and accuracy.

In summary, the Step-DPO method represents a significant advancement in the field of LLMs, providing an effective solution for improving the accuracy and reliability of long-chain mathematical reasoning. Through meticulous data construction and step-wise optimization, Step-DPO sets a new standard for preference optimization in complex reasoning tasks.
