
ReFT: Reasoning with Reinforced Fine-Tuning

(2401.08967)
Published Jan 17, 2024 in cs.CL

Abstract

One way to enhance the reasoning capability of LLMs is to conduct Supervised Fine-Tuning (SFT) using Chain-of-Thought (CoT) annotations. This approach does not show sufficiently strong generalization ability, however, because the training relies only on the given CoT data. In math problem-solving, for example, there is usually only one annotated reasoning path for each question in the training data. Intuitively, it would be better for the algorithm to learn from multiple annotated reasoning paths given a question. To address this issue, we propose a simple yet effective approach called Reinforced Fine-Tuning (ReFT) to enhance the generalizability of learning LLMs for reasoning, with math problem-solving as an example. ReFT first warms up the model with SFT and then employs online reinforcement learning, specifically the PPO algorithm in this paper, to further fine-tune the model, where abundant reasoning paths are automatically sampled given the question and the rewards are naturally derived from the ground-truth answers. Extensive experiments on the GSM8K, MathQA, and SVAMP datasets show that ReFT significantly outperforms SFT, and the performance can potentially be further boosted by combining inference-time strategies such as majority voting and re-ranking. Note that ReFT obtains the improvement by learning from the same training questions as SFT, without relying on extra or augmented training questions. This indicates a superior generalization ability for ReFT.

Figure: Comparison of SFT and ReFT in identifying alternatives within a CoT context.

Overview

  • Reinforced Fine-Tuning (ReFT) improves the reasoning and math problem-solving abilities of LLMs beyond what traditional Supervised Fine-Tuning (SFT) achieves.

  • ReFT introduces a reinforcement learning stage using the Proximal Policy Optimization algorithm to expose the model to multiple reasoning paths, enhancing generalization.

  • ReFT outperforms SFT on math problem-solving tasks across several datasets without requiring extra or augmented training questions.

  • Techniques such as majority voting and re-ranking further improve ReFT's performance at inference time.

  • The ReFT approach is data-efficient, and its codebase is publicly available, supporting further work on the reasoning capabilities of LLMs.

Introduction to Reinforced Fine-Tuning (ReFT)

LLMs are powerful tools that have revolutionized the way we approach complex problem-solving, particularly in natural language processing. These models, trained on vast datasets, can generate text, answer questions, and in some cases even perform arithmetic. However, their ability to reason and solve math problems effectively can still be improved. Typically, these models are refined using Supervised Fine-Tuning (SFT), in which the model is trained to follow the reasoning steps, known as a Chain-of-Thought (CoT), annotated in the training data. Because the training data usually contains only one annotated reasoning path per problem, this setup can limit generalization across diverse problem types.
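
To make the SFT baseline concrete, the sketch below shows roughly what maximum-likelihood training on a single annotated CoT per question looks like. It is a minimal, hypothetical example, not the paper's implementation; the "gpt2" checkpoint, learning rate, and data fields are placeholders.

```python
# Minimal sketch of SFT on CoT annotations (hypothetical setup, not the paper's code).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder base model
model = AutoModelForCausalLM.from_pretrained("gpt2")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

def sft_step(question: str, cot_annotation: str) -> float:
    """One gradient step on a single (question, annotated CoT) pair."""
    text = question + "\n" + cot_annotation + tokenizer.eos_token
    batch = tokenizer(text, return_tensors="pt")
    # Standard next-token cross-entropy over the full sequence;
    # in practice the question tokens are usually masked out of the loss.
    outputs = model(**batch, labels=batch["input_ids"])
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return outputs.loss.item()
```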

Rationale Behind Reinforced Fine-Tuning (ReFT)

Researchers have recognized the limitations of SFT, pointing out its potential inadequacy in generalizing from the CoT annotations it was trained on. Ideally, a model would benefit from exposure to multiple reasoning paths for the same problem during its training. Addressing this need, the team at ByteDance Research has proposed an innovative approach called Reinforced Fine-Tuning (ReFT). This method goes beyond the traditional SFT by incorporating a reinforcement learning stage after the initial SFT process. During this stage, the model employs the Proximal Policy Optimization (PPO) algorithm to explore a variety of reasoning paths and learn from them. The ReFT method leverages ground-truth answers from the training data to naturally derive rewards, eliminating the need for extra training questions or augmented datasets. This leads to improved generalization capabilities for solving math problems.
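
The key point is that the reward comes for free from the training data: each sampled reasoning path is scored by whether its final answer matches the ground truth. Below is a minimal sketch of such a terminal reward, assuming a simple exact-match rule and a hypothetical extract_final_answer helper; the paper's exact reward shaping may differ.

```python
import re

def extract_final_answer(completion: str) -> str | None:
    """Hypothetical helper: pull the last number out of a sampled CoT,
    assuming the chain ends with something like 'The answer is 42'."""
    matches = re.findall(r"-?\d+(?:\.\d+)?", completion)
    return matches[-1] if matches else None

def terminal_reward(completion: str, ground_truth: str) -> float:
    """Reward derived from the ground-truth answer: 1.0 for a correct final
    answer, 0.0 otherwise (a simplified stand-in for the paper's scheme)."""
    predicted = extract_final_answer(completion)
    return 1.0 if predicted is not None and predicted == ground_truth.strip() else 0.0

# In the RL stage, many reasoning paths are sampled for each training question
# and this scalar reward drives PPO updates of the SFT-warmed policy.
```

The PPO loop itself could be run with an off-the-shelf RL library; the sketch only covers how rewards are derived from the existing answers.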

Validation through Experiments

The proposed ReFT method was tested against SFT using math problem-solving as the benchmark task. Experiments were conducted on the GSM8K, MathQA, and SVAMP datasets, which cover a variety of question types with answers ranging from numeric to multiple-choice formats. The results were clear: ReFT significantly outperformed SFT, demonstrating stronger generalization. In addition, applying techniques such as majority voting and re-ranking at inference time further improved ReFT's performance.
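
Majority voting at inference time simply samples several reasoning paths per question and keeps the most frequent final answer. A minimal sketch, assuming a hypothetical sample_cot generation callable and the extract_final_answer helper sketched above:

```python
from collections import Counter
from typing import Callable

def majority_vote(question: str,
                  sample_cot: Callable[[str], str],
                  extract_final_answer: Callable[[str], str | None],
                  k: int = 20) -> str | None:
    """Sample k reasoning paths and return the most common final answer.

    sample_cot(question) is a hypothetical stand-in for the fine-tuned
    model's sampling-based decoder."""
    answers = []
    for _ in range(k):
        completion = sample_cot(question)           # one sampled CoT
        answer = extract_final_answer(completion)   # parse its final answer
        if answer is not None:
            answers.append(answer)
    return Counter(answers).most_common(1)[0][0] if answers else None
```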

Concluding Remarks on ReFT's Advantages

The novel ReFT approach not only surpasses SFT in performance but does so by learning from the same training questions, without additional resources. Rather than relying on the single annotated reasoning path per question, it explores and learns from many automatically sampled reasoning paths. The outcome is a model with improved generalization that copes more effectively with varied and complicated mathematical problems. The codebase for this research has been made publicly accessible, allowing for community collaboration and continued advancement in reasoning for LLMs.
