Emergent Mind

Recursive Introspection: Teaching Language Model Agents How to Self-Improve

(2407.18219)
Published Jul 25, 2024 in cs.LG, cs.AI, and cs.CL

Abstract

A central piece in enabling intelligent agentic behavior in foundation models is to make them capable of introspecting upon their behavior, reasoning, and correcting their mistakes as more computation or interaction is available. Even the strongest proprietary LLMs do not quite exhibit the ability of continually improving their responses sequentially, even in scenarios where they are explicitly told that they are making a mistake. In this paper, we develop RISE: Recursive IntroSpEction, an approach for fine-tuning LLMs to introduce this capability, despite prior work hypothesizing that this capability may not be possible to attain. Our approach prescribes an iterative fine-tuning procedure, which attempts to teach the model how to alter its response after having executed previously unsuccessful attempts to solve a hard test-time problem, with optionally additional environment feedback. RISE poses fine-tuning for a single-turn prompt as solving a multi-turn Markov decision process (MDP), where the initial state is the prompt. Inspired by principles in online imitation learning and reinforcement learning, we propose strategies for multi-turn data collection and training so as to imbue an LLM with the capability to recursively detect and correct its previous mistakes in subsequent iterations. Our experiments show that RISE enables Llama2, Llama3, and Mistral models to improve themselves with more turns on math reasoning tasks, outperforming several single-turn strategies given an equal amount of inference-time computation. We also find that RISE scales well, often attaining larger benefits with more capable models. Our analysis shows that RISE makes meaningful improvements to responses to arrive at the correct solution for challenging prompts, without disrupting one-turn abilities as a result of expressing more complex distributions.

Figure: Two ways to query a RISE model at inference time: with an oracle, used for early termination, or without an oracle, using majority voting.
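The two query modes can be sketched as follows. This is a minimal illustration, not the authors' code. The names `generate`, `is_correct`, `infer_with_oracle`, and `infer_with_majority_vote` are hypothetical. The model is reduced to a function that takes the prompt and the earlier attempts and returns a final-answer string. The five-turn default simply mirrors the five-turn evaluations described in this summary.

```python
from collections import Counter
from typing import Callable, List, Optional

# Hypothetical model interface: (prompt, previous attempts) -> next answer string.
Generator = Callable[[str, List[str]], str]


def infer_with_oracle(prompt: str, generate: Generator,
                      is_correct: Callable[[str], bool],
                      max_turns: int = 5) -> Optional[str]:
    """Query turn by turn and stop at the first answer the oracle accepts."""
    attempts: List[str] = []
    for _ in range(max_turns):
        answer = generate(prompt, attempts)
        if is_correct(answer):
            return answer  # early termination
        attempts.append(answer)  # the failed attempt becomes context for the next turn
    return None  # no correct answer within the turn budget


def infer_with_majority_vote(prompt: str, generate: Generator,
                             max_turns: int = 5) -> str:
    """Run every turn without an oracle, then return the most frequent answer."""
    attempts: List[str] = []
    for _ in range(max_turns):
        attempts.append(generate(prompt, attempts))
    return Counter(attempts).most_common(1)[0][0]
```

Take a scripted generator that answers 12, 14, 14, 15, 14 on successive turns. If 14 is correct, the oracle mode stops at the second turn. The oracle-free mode runs all five turns and returns 14, because 14 is the most frequent answer.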

Overview

  • The paper 'Recursive Introspection: Teaching Language Model Agents How to Self-Improve' introduces RISE, a method that lets LLMs improve their responses over multiple attempts through iterative fine-tuning informed by online imitation learning and reinforcement learning.

  • RISE casts a single-turn prompt as a multi-turn Markov decision process (MDP) whose initial state is the prompt. It then uses on-policy data collection and reward-weighted training so that the model learns to correct its past mistakes and move toward high-quality responses.

  • Experiments on the GSM8K and MATH datasets show that RISE improves LLM performance over successive turns and that its benefits often grow with model capability. Because it can also run without an oracle, the method is relevant to deployments where test-time correctness checks are costly, and it opens directions for future research on self-improving models.
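One way to write the multi-turn MDP that RISE builds from a single-turn prompt is shown below. The notation is illustrative rather than copied from the paper:

- x is the prompt.
- y_t is the model's attempt at turn t.
- f_t is optional feedback, for example a message saying the previous answer was wrong.
- The reward is assumed to be a binary correctness signal.

```latex
\begin{aligned}
s_1 &= x, \\
a_t &= y_t \sim \pi_\theta(\cdot \mid s_t), \\
s_{t+1} &= \big[\, s_t,\; y_t,\; f_t \,\big], \\
r(s_t, a_t) &= \mathbb{1}\{\, y_t \text{ is a correct answer to } x \,\}.
\end{aligned}
```

Each new state therefore contains the prompt plus every earlier attempt and its feedback. This is what lets the policy condition on its own previous mistakes.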


The paper "Recursive Introspection: Teaching Language Model Agents How to Self-Improve," authored by Yuxiao Qu, Tianjun Zhang, Naman Garg, and Aviral Kumar, introduces RISE (Recursive IntroSpEction), a method that enables LLMs to improve their responses over multiple sequential attempts. Even strong proprietary LLMs often fail to correct their mistakes over successive attempts, even when explicitly told that they are wrong. The paper addresses this limitation with an iterative fine-tuning procedure informed by principles of online imitation learning and reinforcement learning.

Summary of RISE Approach

RISE implements a multi-turn fine-tuning paradigm. It poses fine-tuning on a single-turn prompt as solving a multi-turn Markov decision process (MDP) whose initial state is the prompt. It relies on two primary strategies for iterative self-improvement: on-policy data collection and reward-weighted training. The process begins by unrolling the current model over several turns, so that the collected rollouts contain its own attempts and errors. The responses in these rollouts are then replaced with improved ones, obtained either by distillation from a more capable model or by self-improvement using the model's own on-policy samples.
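The data-collection loop described above can be sketched as follows. It is not the authors' implementation and rests on several assumptions:

- `sample` and `propose` are hypothetical callables. `sample` stands in for the current model, and `propose` is the source of improved responses. For the distillation variant, pass a stronger teacher's sampler as `propose`. Leave it unset to use best-of-N self-generated samples.
- `reward` is a hypothetical scorer that returns 1.0 for a correct response.
- The `FEEDBACK` string is placeholder wording.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional

# Hypothetical sketch: placeholder feedback wording, not taken from the paper.
FEEDBACK = "Your previous answer may be incorrect. Please try again."


@dataclass
class Example:
    context: List[str]  # prompt, earlier attempts, and feedback messages
    target: str         # improved response the model should produce next
    reward: float       # reward of the target, kept for reward-weighted training


def collect_multi_turn_data(prompt: str,
                            sample: Callable[[List[str]], str],
                            reward: Callable[[str], float],
                            propose: Optional[Callable[[List[str]], str]] = None,
                            num_turns: int = 3,
                            n: int = 4) -> List[Example]:
    """Unroll the current model and relabel each turn with the best of n proposals."""
    propose = propose or sample  # default: self-generated best-of-n
    context = [prompt]
    examples: List[Example] = []
    for _ in range(num_turns):
        # On-policy attempt from the current model; it stays in the context even if wrong.
        attempt = sample(context)
        # Improved target for this state: the highest-reward of n proposals.
        candidates = [propose(context) for _ in range(n)]
        best = max(candidates, key=reward)
        examples.append(Example(list(context), best, reward(best)))
        if reward(attempt) >= 1.0:
            break  # the model already solved the prompt; stop unrolling
        context += [attempt, FEEDBACK]
    return examples
```

Keeping the model's own, possibly wrong, attempt in the context while training on a better target is what exposes the model to its own mistakes. Storing the target's reward lets a later weighted objective down-weight poor relabels instead of discarding them.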

  1. Data Collection for Self-Improvement: RISE iteratively fine-tunes the model on its own previous responses, paired with better responses that come either from more capable teacher models or from the model's own best-of-N samples. This iterative rollout exposes the model to its own mistakes and teaches it to correct them over multiple attempts.
  2. Policy Improvement: The collected data is used to fine-tune the model with a reward-weighted regression objective. Rather than discarding suboptimal responses, this objective weights each response by its reward, so the model learns from both high-quality and suboptimal responses. This covers a wider part of the state-action space and strengthens the model's self-improvement capability.
  3. Inference with and without an Oracle: With an oracle, each answer is checked and inference terminates early once an answer is correct. Without an oracle, the model runs for a fixed number of turns, and majority voting over the answers produced across turns selects the final response. The oracle-free mode keeps the method applicable when no correctness check is available at test time.
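A minimal sketch of a reward-weighted regression loss is shown below. It assumes the common exponentiated-reward form, in which each target response is weighted by exp(r/τ), and it also normalizes the weights over the batch. The paper's exact weighting and normalization may differ, and `rwr_weights` and `rwr_loss` are hypothetical names. The model itself is abstracted away: the loss takes the log-probability of each target response under the model being trained.

```python
import math
from typing import List, Sequence


def rwr_weights(rewards: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Exponentiated-reward weights exp(r / temperature), normalised to sum to 1 (assumed)."""
    m = max(rewards)  # subtract the max reward for numerical stability
    raw = [math.exp((r - m) / temperature) for r in rewards]
    total = sum(raw)
    return [w / total for w in raw]


def rwr_loss(target_logprobs: Sequence[float], rewards: Sequence[float],
             temperature: float = 1.0) -> float:
    """Reward-weighted negative log-likelihood of the target responses.

    target_logprobs[i] is log pi(target_i | context_i) under the model being trained.
    Low-reward examples receive smaller, but non-zero, weight instead of being dropped.
    """
    weights = rwr_weights(rewards, temperature)
    return -sum(w * lp for w, lp in zip(weights, target_logprobs))
```

Consider rewards [1.0, 0.0] at temperature 1.0. The weights come out to roughly 0.73 and 0.27, so the incorrect response still contributes to the loss. Lowering the temperature pushes the objective toward training only on the best responses.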

Experimental Evaluation

RISE is evaluated on two math reasoning datasets, GSM8K and MATH. For instance, on GSM8K, RISE improved the performance of Llama3-8B by 8.2% and Mistral-7B by 6.6% over five turns, outperforming parallel sampling strategies. The experiments also indicate that RISE scales well with model capability, showing a 17.7% improvement for Llama2-7B and a 23.9% improvement for Mistral-7B over five turns of introspection. Similar trends were observed on the MATH dataset.

The authors also performed an extensive ablation study examining when and why self-improvement over multiple turns is achievable. Notably, fine-tuning with RISE generally yielded lower perplexity values, suggesting an advantage of conditioning on multi-turn data over single-turn data.

Theoretical and Practical Implications

The RISE framework gives LLMs the ability to improve their responses iteratively, which has several practical implications:

  • Enhanced Model Performance: By leveraging multi-turn fine-tuning and reward-weighted learning, RISE consistently improves the performance of LLMs in scenarios requiring complex reasoning and error correction.
  • Generalization to Out-of-Distribution Prompts: RISE demonstrates the potential for models to generalize self-improvement strategies to novel, unseen prompts, indicating robustness across different datasets.
  • Resource Efficiency: The ability to operate without an oracle for early termination makes RISE particularly suitable for deployment in resource-constrained environments, where querying the reward function is costly.
  • Foundation for Future Research: The findings open avenues for further research into fully online reinforcement learning methods and integration with general instruction-tuning pipelines to enhance the self-improvement capabilities of LLMs.

Conclusion

"Recursive Introspection: Teaching Language Model Agents How to Self-Improve" presents a compelling methodology for enhancing the adaptive capabilities of LLMs. By transforming single-turn problems into multi-turn MDPs and implementing an iterative fine-tuning approach, RISE enables models to learn from their mistakes and improve sequentially. This research underscores the potential of integrating imitation learning and reinforcement learning techniques to develop more intelligent, self-correcting language models, paving the way for advancements in AI-driven decision-making and reasoning tasks.
