Emergent Mind

WARP: On the Benefits of Weight Averaged Rewarded Policies

(2406.16768)
Published Jun 24, 2024 in cs.LG and cs.AI

Abstract

Reinforcement learning from human feedback (RLHF) aligns LLMs by encouraging their generations to have high rewards, using a reward model trained on human preferences. To prevent the forgetting of pre-trained knowledge, RLHF usually incorporates a KL regularization; this forces the policy to remain close to its supervised fine-tuned initialization, though it hinders the reward optimization. To tackle the trade-off between KL and reward, in this paper we introduce a novel alignment strategy named Weight Averaged Rewarded Policies (WARP). WARP merges policies in the weight space at three distinct stages. First, it uses the exponential moving average of the policy as a dynamic anchor in the KL regularization. Second, it applies spherical interpolation to merge independently fine-tuned policies into a new enhanced one. Third, it linearly interpolates between this merged model and the initialization, to recover features from pre-training. This procedure is then applied iteratively, with each iteration's final model used as an advanced initialization for the next, progressively refining the KL-reward Pareto front, achieving superior rewards at fixed KL. Experiments with GEMMA policies validate that WARP improves their quality and alignment, outperforming other open-source LLMs.

Figure: Iterative WARP length bias.

Overview

  • The paper introduces the Weight Averaged Rewarded Policies (WARP) method to improve reinforcement learning from human feedback (RLHF) for aligning LLMs by strategically balancing reward maximization and Kullback-Leibler (KL) regularization.

  • WARP uses an Exponential Moving Average (EMA) of the policy as a dynamic anchor, Spherical Linear Interpolation (SLERP) to merge policies, and Linear Interpolation Towards Initialization (LITI) to balance task-specific capabilities with general pre-trained knowledge, progressively refining the policy over iterations.

  • Empirical results show that WARP reaches a better KL-reward Pareto front, outperforms open-source models such as Mixtral, and improves text generation quality and alignment as the method is applied iteratively.

Analyzing "WARP: On the Benefits of Weight Averaged Rewarded Policies"

The paper "WARP: On the Benefits of Weight Averaged Rewarded Policies" explores a novel strategy for reinforcement learning from human feedback (RLHF) to align LLMs. Its key contribution, the Weight Averaged Rewarded Policies (WARP) method, targets the trade-off between reward maximization and Kullback-Leibler (KL) regularization. It does so through a series of weight averaging steps: exponential moving averages (EMA), spherical linear interpolation (SLERP), and linear interpolation towards initialization (LITI) are combined to iteratively refine the policy weights, improving alignment while limiting the forgetting of pre-trained knowledge.

Problem Context and Challenges

Aligning LLMs like GPT-4 and Gemini with human values is critical for safe AI deployment. The challenge in RLHF is that optimizing too aggressively against a reward model can make the policy forget pre-trained knowledge, exploit flaws in the imperfect reward model (reward hacking), and lose diversity in its generated outputs. KL regularization is commonly used to mitigate this by anchoring the policy close to its initialization to retain generality, at the cost of lower reward. This tension between KL regularization and reward maximization is the crux of the difficulty in aligning LLMs.
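
For reference, the KL-regularized objective behind this trade-off is usually written as below. The notation is ours rather than necessarily the paper's: $r$ is the reward model, $\beta > 0$ the regularization strength, $\mathcal{D}$ the prompt distribution, and $\pi_{\text{anchor}}$ the reference policy (the SFT initialization in standard RLHF; WARP replaces it with an exponential moving average of the policy).

```latex
\max_{\theta}\;
\mathbb{E}_{x \sim \mathcal{D},\; y \sim \pi_\theta(\cdot \mid x)}\big[ r(x, y) \big]
\;-\; \beta\, \mathbb{E}_{x \sim \mathcal{D}}
\Big[ \mathrm{KL}\big( \pi_\theta(\cdot \mid x) \,\|\, \pi_{\text{anchor}}(\cdot \mid x) \big) \Big]
```

A larger $\beta$ keeps the policy closer to the anchor (less forgetting, lower reward); a smaller $\beta$ allows more reward optimization at the price of drifting further from it.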

Proposed Solution: WARP

WARP introduces a three-pronged approach to address the above challenges:

  1. Exponential Moving Average (EMA) as Dynamic Anchor: WARP first uses the EMA of the policy weights as a dynamic anchor in the KL regularization during reinforcement learning. Unlike a static supervised fine-tuned (SFT) initialization, this anchor slowly moves towards the policy, which produces an automatic annealing effect: the KL constraint is gradually relaxed, allowing higher rewards while the policy stays close to a smoothed, recent version of itself. The EMA also acts as a mean teacher, distilling a stabilized version of the policy back into it.

  2. Spherical Linear Interpolation (SLERP): Next, WARP applies SLERP to merge multiple independently fine-tuned policies. The merged model combines the strengths of the individual policies and achieves higher rewards than any single one of them. SLERP preserves the norm of the task vectors (the differences between fine-tuned and initial weights), avoiding the norm shrinkage, and the resulting reward loss, that linear interpolation suffers when the task vectors are not aligned.

  3. Linear Interpolation Towards Initialization (LITI): Finally, WARP uses LITI to interpolate the SLERP-merged weights back towards the initialization (the SFT model in the first iteration). This trades newly acquired task-specific capabilities against retained pre-trained knowledge, and varying the interpolation coefficient traces an improved Pareto front in KL-reward space. The interpolated model then serves as the initialization for the next iteration, enabling progressive improvement.
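
The three weight-space operations can be sketched on flat weight vectors stored as Python lists. This is a minimal illustration under stated assumptions, not the paper's implementation: the function names are hypothetical, the EMA rate `mu=0.01` and the coefficients `lam` and `eta` are illustrative defaults, and a real model would typically apply these operations tensor by tensor rather than to one flat vector.

```python
import math


def ema_update(anchor, policy, mu=0.01):
    """Dynamic anchor step: anchor <- (1 - mu) * anchor + mu * policy (mu is illustrative)."""
    return [(1.0 - mu) * a + mu * p for a, p in zip(anchor, policy)]


def slerp(init, theta1, theta2, lam=0.5, eps=1e-8):
    """Spherical interpolation between task vectors theta1 - init and theta2 - init."""
    d1 = [a - b for a, b in zip(theta1, init)]
    d2 = [a - b for a, b in zip(theta2, init)]
    dot = sum(x * y for x, y in zip(d1, d2))
    norms = math.sqrt(sum(x * x for x in d1)) * math.sqrt(sum(x * x for x in d2))
    # Angle between the two task vectors (cosine clamped against rounding error).
    omega = math.acos(max(-1.0, min(1.0, dot / max(norms, eps))))
    if math.sin(omega) < eps:
        # (Anti)parallel task vectors: SLERP is ill-defined, fall back to LERP.
        w1, w2 = 1.0 - lam, lam
    else:
        w1 = math.sin((1.0 - lam) * omega) / math.sin(omega)
        w2 = math.sin(lam * omega) / math.sin(omega)
    return [i + w1 * a + w2 * b for i, a, b in zip(init, d1, d2)]


def liti(init, merged, eta=0.5):
    """Linear interpolation towards initialization: (1 - eta) * init + eta * merged."""
    return [(1.0 - eta) * i + eta * m for i, m in zip(init, merged)]


if __name__ == "__main__":
    init = [0.0, 0.0]
    # Two orthogonal unit task vectors.
    merged = slerp(init, [1.0, 0.0], [0.0, 1.0])
    # ~[0.7071, 0.7071] with norm ~1.0; a plain average [0.5, 0.5] would have norm ~0.7071.
    print(merged, math.hypot(*merged))
    # Halfway back towards init: ~[0.3536, 0.3536].
    print(liti(init, merged, eta=0.5))
```

The demo uses orthogonal task vectors on purpose: this is the case where a plain average loses the most norm, while SLERP keeps it intact.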

Empirical Results

Experiments with Gemma policies validate WARP's efficacy. Policies fine-tuned with an EMA anchor reach a better KL-reward Pareto front than those with a static SFT anchor. SLERP merging of independent policies substantially increases rewards, and the iterative application of WARP, with LITI providing each new initialization, improves performance from one iteration to the next.

Concretely, the paper shows that merging policies with SLERP across three training stages and applying WARP iteratively yields better alignment and higher rewards than other open-source LLMs. The resulting policy generates higher-quality, better-aligned text, outperforms open-source models such as Mixtral, and achieves notable scores on standard benchmarks such as MMLU and GSM8K.

Theoretical and Practical Implications

The theoretical underpinning of WARP lies in its use of task vectors and weight averaging to retain model generality while optimizing for task-specific rewards. The empirical results highlight WARP's ability to navigate the KL-reward trade-off in RLHF, mitigating reward hacking and catastrophic forgetting.
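
The task-vector argument for SLERP over linear interpolation can be made concrete with a short calculation. Assume, for illustration only, two task vectors $\delta_1 = \theta_1 - \theta_{\text{init}}$ and $\delta_2 = \theta_2 - \theta_{\text{init}}$ with equal norm $\|\delta\|$ and angle $\Omega \in (0, \pi)$ between them, merged at $\lambda = \tfrac{1}{2}$:

```latex
\Big\| \tfrac{1}{2}(\delta_1 + \delta_2) \Big\|
  = \tfrac{1}{2}\sqrt{2\|\delta\|^2 + 2\|\delta\|^2 \cos\Omega}
  = \|\delta\| \cos\tfrac{\Omega}{2}
  \qquad \text{(LERP)}

\Big\| \frac{\sin(\Omega/2)}{\sin\Omega}\,(\delta_1 + \delta_2) \Big\|
  = \frac{\sin(\Omega/2)}{2\sin(\Omega/2)\cos(\Omega/2)} \cdot 2\|\delta\| \cos\tfrac{\Omega}{2}
  = \|\delta\|
  \qquad \text{(SLERP)}
```

For nearly orthogonal task vectors ($\Omega \approx \pi/2$), the linear average keeps only about $1/\sqrt{2} \approx 71\%$ of the task-vector norm, whereas SLERP keeps all of it.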

More broadly, WARP aligns with the ideals of iterative amplification, offering a scalable approach to progressively refining LLMs. Because the fine-tuning runs within an iteration are independent, they can be distributed across workers, which facilitates collaborative training and open-source contributions; and because merging happens in weight space, the result is a single model with no inference or memory overhead.
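
The outer loop can be sketched end to end on a toy problem. Everything here is a hypothetical stand-in chosen so the example runs in plain Python: `toy_rl_run` replaces RLHF with gradient ascent on the toy reward $-\|\theta - \text{target}\|^2$, the KL term is replaced by a squared distance to the EMA anchor, only two runs are merged per iteration, and all hyperparameters are illustrative rather than taken from the paper.

```python
import math
import random


def slerp(init, theta1, theta2, lam=0.5, eps=1e-8):
    """SLERP of the task vectors theta1 - init and theta2 - init (flat lists)."""
    d1 = [a - b for a, b in zip(theta1, init)]
    d2 = [a - b for a, b in zip(theta2, init)]
    dot = sum(x * y for x, y in zip(d1, d2))
    norms = math.sqrt(sum(x * x for x in d1)) * math.sqrt(sum(x * x for x in d2))
    omega = math.acos(max(-1.0, min(1.0, dot / max(norms, eps))))
    if math.sin(omega) < eps:  # (anti)parallel task vectors: fall back to LERP
        w1, w2 = 1.0 - lam, lam
    else:
        w1 = math.sin((1.0 - lam) * omega) / math.sin(omega)
        w2 = math.sin(lam * omega) / math.sin(omega)
    return [i + w1 * a + w2 * b for i, a, b in zip(init, d1, d2)]


def toy_rl_run(init, target, seed, steps=200, lr=0.05, beta=0.1, mu=0.01):
    """Hypothetical stand-in for one RL fine-tuning run with an EMA anchor.

    Toy assumptions: reward(theta) = -||theta - target||^2 and the KL term is
    replaced by beta * ||theta - anchor||^2. Seeded noise makes runs differ.
    """
    rng = random.Random(seed)
    theta, anchor = list(init), list(init)
    for _ in range(steps):
        # Gradient ascent on reward - beta * ||theta - anchor||^2, plus noise.
        theta = [t + lr * (2.0 * (g - t) - 2.0 * beta * (t - a)) + 0.01 * rng.gauss(0.0, 1.0)
                 for t, g, a in zip(theta, target, anchor)]
        # The EMA anchor slowly follows the policy (dynamic anchor).
        anchor = [(1.0 - mu) * a + mu * t for a, t in zip(anchor, theta)]
    return theta


def warp(init, target, iterations=3, eta=0.5):
    """Toy iterative WARP: two independent runs, SLERP merge, then LITI, per iteration."""
    theta_init = list(init)
    for it in range(iterations):
        # Independent runs from the same init; these could run on separate workers.
        runs = [toy_rl_run(theta_init, target, seed=100 * it + m) for m in range(2)]
        merged = slerp(theta_init, runs[0], runs[1])
        # LITI: keep only part of the move, relative to this iteration's init.
        theta_init = [(1.0 - eta) * i + eta * w for i, w in zip(theta_init, merged)]
    return theta_init


if __name__ == "__main__":
    target = [1.0, -1.0, 0.5]
    for n in (1, 2, 3):
        # Distance to the toy target is expected to shrink as iterations increase.
        print(n, math.dist(warp([0.0, 0.0, 0.0], target, iterations=n), target))
```

In this toy, each LITI step deliberately gives back part of the progress, and the next iteration starts from that more advanced initialization, which mirrors the iterative refinement described above. The merged result is still a single weight vector, so nothing grows with the number of runs.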

Future Directions

Future research could explore the integration of WARP with advanced reward models or its applicability in federated learning scenarios. Additionally, examining the effects of diverse objectives within WARP, as well as merging generative and discriminative tasks, could open new avenues for enhancing model robustness and generalization.

In summary, "WARP: On the Benefits of Weight Averaged Rewarded Policies" presents a comprehensive and iterative approach to improving LLM alignment, offering significant advancements and practical benefits in the realm of reinforcement learning from human feedback.
