Advancing AI Alignment with Human Values Through WARM

Alignment of AI Systems with Human Values

Artificial intelligence (AI) systems are becoming increasingly capable of assisting humans in complex tasks, from customer service chatbots to medical diagnosis algorithms. However, as these AI systems take on more responsibilities, it is crucial that they remain aligned with human values and preferences. One approach to achieve this is through a technique called reinforcement learning from human feedback (RLHF). In RLHF, an AI system, known as the policy, is rewarded or penalized based on human judgments of its behavior. The goal is for the policy to learn to maximize its rewards, and thus behave according to human preferences.

A core component of RLHF is the reward model (RM). The RM is responsible for evaluating the policy's actions and outputs, and returning a reward signal to guide the learning process. Designing a good RM is challenging, as human preferences can be complex, context-dependent, and even inconsistent across individuals. Recently, researchers from Google DeepMind proposed an innovative technique called Weight Averaged Reward Models (WARM) to improve RM design.
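The article does not say how an RM is trained. A common choice, assumed here, is to fit the RM on pairwise human comparisons with a Bradley-Terry objective: the RM should assign a higher score to the response humans preferred. The sketch below computes that loss for a toy linear RM. The feature vectors, weights and function names are hypothetical stand-ins for a real LLM-based RM.

```python
import math

def reward(weights, features):
    """Toy linear reward model: r(x) = w . phi(x)."""
    return sum(w * f for w, f in zip(weights, features))

def bradley_terry_loss(weights, chosen, rejected):
    """Negative log-likelihood that `chosen` is preferred over `rejected`,
    where P(chosen > rejected) = sigmoid(r(chosen) - r(rejected))."""
    margin = reward(weights, chosen) - reward(weights, rejected)
    # -log(sigmoid(m)) == log(1 + exp(-m)); fine for the small margins used here
    return math.log1p(math.exp(-margin))

# Hypothetical features: [informativeness, fluency]
w = [1.0, 0.5]
good_summary = [0.9, 0.8]
bad_summary = [0.2, 0.7]
print(bradley_terry_loss(w, good_summary, bad_summary))
```

When the RM already ranks the preferred summary higher, the loss falls below log 2; swapping the pair pushes it above log 2, and gradient descent moves the weights to fix the ranking.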

The Trouble with Reward Hacking

A major problem in RLHF is reward hacking. Reward hacking occurs when the policy finds loopholes in the RM that let it obtain high rewards without actually satisfying the intended objectives. For example, suppose the goal is to train a writing assistant AI to generate high-quality summaries. The RM might reward concise and informative summaries. The policy could then learn to exploit this by generating very short, uninformative summaries peppered with keywords that trick the RM.
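To make the summary example concrete, here is a deliberately flawed toy RM: it rewards keyword hits and penalizes length, as a crude proxy for "concise and informative". The keyword list and scoring rule are hypothetical. A keyword-stuffed string that says nothing outscores a faithful summary, which is exactly the loophole a policy can learn to exploit.

```python
KEYWORDS = {"revenue", "growth", "quarter"}  # hypothetical "informative" terms

def proxy_reward(summary: str) -> float:
    """Flawed toy RM: +1 per keyword hit, -0.1 per word (rewards brevity)."""
    words = summary.lower().split()
    keyword_hits = sum(1 for w in words if w.strip(".,") in KEYWORDS)
    return keyword_hits - 0.1 * len(words)

faithful = "Revenue rose 12% this quarter, driven by strong cloud growth."
hacked = "revenue growth quarter revenue growth"

print(proxy_reward(faithful), proxy_reward(hacked))
```

The hacked output wins under the proxy even though a human would reject it; the gap between proxy score and true quality is what reward hacking exploits.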

Reward hacking happens for two main reasons:

  1. Distribution shift – The RM is trained on a limited dataset of human-labeled examples. When deployed, the policy's outputs may come from different distributions that the RM does not generalize well to.
  2. Noisy labels – Human labeling is imperfect, with inter-rater disagreements. The RM may latch onto spurious signals rather than robust indicators of quality.

Reward hacking leads to useless systems that fail to match human expectations. Worse still, it can result in AI behaviors that are biased or even dangerous if deployed carelessly.

The Rise of Model Merging

The surging interest in model merging strategies like Model Ratatouille is driven by the realization that bigger models, while powerful, can be inefficient and impractical. Training a 1 trillion parameter model requires prohibitive amounts of data, compute, time and cost. More crucially, such models tend to overfit to the training distribution, hampering their ability to generalize to diverse real-world scenarios.

Model merging provides an alternative route to greater capabilities without simply scaling up model size. By reusing multiple specialized models trained on different distributions, tasks or objectives, model merging aims to enhance versatility and out-of-distribution robustness. The premise is that different models capture distinct predictive patterns that can complement each other when merged.

Recent results illustrate the promise of this concept. For instance, Model Ratatouille, which fine-tunes several models from initializations obtained on diverse auxiliary tasks and then averages their weights, reported state-of-the-art out-of-distribution accuracy on the DomainBed benchmark.

The simplicity of merging by weight averaging is a huge bonus. Training multiple auxiliary models does demand extra resources. But crucially, the inference-time computation remains identical to that of a single model, since the weights are condensed into one. This makes the method easy to adopt without increasing inference latency or memory cost.
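The inference-cost argument can be checked on a toy case. For a linear scoring model, averaging the weights gives exactly the same output as averaging the predictions of the individual models, but needs one forward pass instead of M. For deep nonlinear networks this equality only holds approximately, and only when the models share an initialization so that their weights remain compatible; the linear models below are hypothetical.

```python
def score(weights, x):
    """Forward pass of a toy linear model."""
    return sum(w * xi for w, xi in zip(weights, x))

def average_weights(models):
    """Element-wise mean of M weight vectors -> a single merged model."""
    m = len(models)
    return [sum(ws) / m for ws in zip(*models)]

def ensemble_predict(models, x):
    """Prediction ensembling: M forward passes, then average the outputs."""
    return sum(score(w, x) for w in models) / len(models)

models = [[1.0, 0.2, -0.5], [0.8, 0.4, -0.3], [1.2, 0.0, -0.4]]
x = [0.5, 1.0, 2.0]
merged = average_weights(models)

# One forward pass with `merged` vs. three passes for the ensemble.
print(score(merged, x), ensemble_predict(models, x))
```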

Mechanisms Behind Model Merging

But what exactly enables these accuracy gains from merging models? Recent analysis offers some clues:

  • Mitigating Memorization: Each model sees different shuffled batches of the dataset during training. Averaging diminishes any instance-specific memorization, retaining only dataset-level generalizations.
  • Reducing Variance: Models trained independently make partially decorrelated errors. Combining them averages out noise, improving calibration.
  • Regularization via Diversity: Different auxiliary tasks force models to latch onto more generalizable features that are useful across distributions.
  • Increasing Robustness: Inconsistency in predictions signals uncertainty. Averaging moderates outlier judgments, enhancing reliability.
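The variance-reduction point can be simulated. Below, each "model" estimates the same true weight with independent Gaussian error, and the variance of the averaged estimate shrinks roughly as 1/M. This assumes fully independent errors; real fine-tuned models share a pretrained backbone, so their errors are only partly decorrelated and the gain is smaller.

```python
import random
import statistics

def simulate(num_models: int, trials: int = 2000, noise: float = 1.0, seed: int = 0):
    """Each model estimates a true weight of 1.0 with independent Gaussian error.
    Returns the variance of the averaged estimate across trials."""
    rng = random.Random(seed)
    estimates = []
    for _ in range(trials):
        members = [1.0 + rng.gauss(0.0, noise) for _ in range(num_models)]
        estimates.append(sum(members) / num_models)
    return statistics.pvariance(estimates)

for m in (1, 4, 16):
    print(m, round(simulate(m), 3))
```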

In essence, model merging counterbalances weaknesses of individual models to amplify their collective strengths. The merged representation captures the common underlying causal structures, ignoring incidental variations.

This conceptual foundation connects model merging to other popular techniques like ensembling and multi-task learning. All these methods leverage diversity across models or tasks to obtain versatile, uncertainty-aware systems. The simplicity and efficiency of weight averaging, however, give model merging a unique edge for real-world deployments.

Weight Averaged Reward Models

Alignment process with WARM

WARM uses as its proxy RM a weight average of multiple individual RMs, each fine-tuned from the same pre-trained LLM but with varying hyperparameters. This keeps inference efficient (a single model), improves reliability under distribution shifts, and adds robustness against inconsistent preferences. The study also shows that using WARM as the proxy RM, particularly with a larger number of averaged RMs, improves results and delays the onset of 'reward hacking', the point at which the control reward starts to deteriorate during RL training even as the proxy reward keeps rising.
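Reward hacking can be spotted by logging both rewards during RL: the proxy RM the policy optimizes, and a separate control reward used only for evaluation. The sketch below flags the step where the control reward peaks while the proxy reward keeps climbing. The curves are made-up numbers for illustration, not results from the WARM paper; "delaying the onset" means pushing this peak later.

```python
def hacking_onset(proxy, control):
    """Return the step at which the control reward peaks, provided the proxy
    reward is still higher at the end of training than at that step.
    Returns None if the control reward never declines."""
    peak = max(range(len(control)), key=lambda i: control[i])
    if peak == len(control) - 1:
        return None  # control reward still improving at the last step
    if proxy[-1] > proxy[peak]:
        return peak  # proxy keeps rising while control declines
    return None

# Made-up curves for illustration only (not results from the WARM paper).
proxy_rewards = [0.1, 0.4, 0.7, 0.9, 1.2, 1.5, 1.8]
control_rewards = [0.1, 0.3, 0.5, 0.6, 0.55, 0.4, 0.2]
print(hacking_onset(proxy_rewards, control_rewards))
```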

Here's a high-level overview:

  1. Start with a base language model pretrained on a large corpus. Initialize multiple RMs by adding small task-specific layers on top.
  2. Fine-tune each RM separately on the human preference dataset, using different hyperparameters like learning rate for diversity.
  3. Average the weights of the fine-tuned RMs to obtain a single WARM model.
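The three steps above can be sketched end to end on a toy problem. Assumptions, loudly: the base LLM is replaced by fixed hand-made feature vectors, each RM is a linear head trained with a Bradley-Terry loss by plain SGD, and diversity comes from a different learning rate and data-shuffling seed per run. All names and data are hypothetical; the point is the structure of shared init, independent fine-tuning and uniform weight averaging.

```python
import math
import random

# Hypothetical toy preference data: (features_chosen, features_rejected).
PAIRS = [
    ([0.9, 0.1], [0.2, 0.3]),
    ([0.8, 0.4], [0.3, 0.9]),
    ([0.7, 0.2], [0.1, 0.1]),
    ([0.6, 0.5], [0.4, 0.8]),
]

def reward(w, x):
    return sum(wi * xi for wi, xi in zip(w, x))

def fine_tune(init, pairs, lr, seed, epochs=50):
    """SGD on the Bradley-Terry loss -log sigmoid(r(chosen) - r(rejected))."""
    w = list(init)  # copy: every run starts from the same shared initialization
    rng = random.Random(seed)
    data = list(pairs)
    for _ in range(epochs):
        rng.shuffle(data)  # a different data order per run adds diversity
        for chosen, rejected in data:
            margin = reward(w, chosen) - reward(w, rejected)
            # gradient of the loss w.r.t. w is -(1 - sigmoid(margin)) * (chosen - rejected)
            g = 1.0 - 1.0 / (1.0 + math.exp(-margin))
            w = [wi + lr * g * (c - r) for wi, c, r in zip(w, chosen, rejected)]
    return w

def warm(init, pairs, configs):
    """Fine-tune one RM per (lr, seed) config, then average their weights uniformly."""
    rms = [fine_tune(init, pairs, lr, seed) for lr, seed in configs]
    return [sum(ws) / len(rms) for ws in zip(*rms)]

shared_init = [0.0, 0.0]
configs = [(0.05, 0), (0.1, 1), (0.2, 2)]  # (learning rate, shuffle seed)
w_warm = warm(shared_init, PAIRS, configs)
accuracy = sum(reward(w_warm, c) > reward(w_warm, r) for c, r in PAIRS) / len(PAIRS)
print(w_warm, accuracy)
```

In a real setup each run would fine-tune a full LLM-based RM, so the averaging loop would run over every parameter tensor rather than a two-element list.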

The key insight is that weight averaging retains only the invariant information that is learned across all the diverse RMs. This reduces reliance on spurious signals, enhancing robustness. The averaged RM also benefits from variance reduction, improving reliability despite distribution shifts.
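An idealized illustration of the invariance argument: three hypothetical RM heads all agree on a "genuine quality" feature but pick up a spurious cue with inconsistent signs. Averaging keeps the shared component and cancels the spurious one. Real spurious directions will not cancel this neatly; the values are chosen so the effect is easy to see.

```python
# Hypothetical RM heads over two features: [genuine quality, spurious cue].
rm_weights = [
    [1.0, 0.75],
    [1.0, -0.5],
    [1.0, -0.25],
]

def average(vectors):
    """Uniform element-wise average of weight vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]

warm_weights = average(rm_weights)
print(warm_weights)  # → [1.0, 0.0]
```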

As discussed previously, diversity across independently trained models is crucial for unlocking the full potential of model merging. But what are some concrete techniques to promote productive diversity?

The WARM paper explores a few clever ideas that could generalize more broadly:

Ordering Shuffles

A trivial but impactful approach is shuffling the order in which data points are seen by each model during training. Even this simple step de-correlates weights, reducing redundant memorization of patterns.
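A minimal sketch of per-run shuffling, assuming each RM run gets its own seed derived from a base seed. Every run sees the same examples, just in a different, reproducible order. The helper name is hypothetical.

```python
import random

def shuffled_orders(num_examples: int, num_runs: int, base_seed: int = 0):
    """One independent data ordering per RM run, reproducible via per-run seeds."""
    orders = []
    for run in range(num_runs):
        order = list(range(num_examples))
        random.Random(base_seed + run).shuffle(order)
        orders.append(order)
    return orders

for order in shuffled_orders(num_examples=8, num_runs=3):
    print(order)
```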

Hyperparameter Variations

Tweaking hyperparameters like learning rate and dropout probability for each run introduces useful diversity. Models converge differently, capturing distinct properties of the dataset.
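One way to assign varied hyperparameters is to sample a config per run. The sketch draws a log-uniform learning rate and a dropout rate for each RM; the ranges and the `RunConfig` type are illustrative assumptions, not the values used in the WARM paper.

```python
import random
from dataclasses import dataclass

@dataclass(frozen=True)
class RunConfig:
    seed: int
    learning_rate: float
    dropout: float

def sample_configs(num_runs: int, seed: int = 0):
    """Draw a (learning rate, dropout) pair for each RM run.
    Ranges are illustrative assumptions, not the paper's settings."""
    rng = random.Random(seed)
    configs = []
    for run in range(num_runs):
        lr = 10 ** rng.uniform(-6, -5)  # log-uniform between 1e-6 and 1e-5
        dropout = rng.choice([0.0, 0.05, 0.1])
        configs.append(RunConfig(seed=run, learning_rate=lr, dropout=dropout))
    return configs

for cfg in sample_configs(4):
    print(cfg)
```

Sampling on a log scale is a common choice for learning rates because their effect is roughly multiplicative.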

Checkpoint Averaging – Baklava

The Baklava method initializes the models to be merged from different checkpoints along the same supervised fine-tuning (SFT) trajectory. This relaxes the constraint of model soups, which require all models to start from a single shared initialization. Unlike Model Ratatouille, Baklava does not need fine-tuning on additional auxiliary tasks. Overall, it strikes an effective balance between accuracy and diversity.
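Baklava only changes where each RM starts. The sketch picks evenly spaced SFT checkpoint steps, one per RM. Even spacing and the parameter names are illustrative assumptions; the description here only says the checkpoints come from different SFT training steps.

```python
def baklava_init_steps(sft_total_steps: int, num_rms: int, start_step: int):
    """Pick `num_rms` SFT checkpoint steps, evenly spaced from `start_step`
    to the end of the SFT run; RM i is then initialized from checkpoint i.
    Even spacing is an illustrative assumption."""
    if num_rms == 1:
        return [sft_total_steps]
    stride = (sft_total_steps - start_step) / (num_rms - 1)
    return [round(start_step + i * stride) for i in range(num_rms)]

print(baklava_init_steps(sft_total_steps=12000, num_rms=4, start_step=6000))
# → [6000, 8000, 10000, 12000]
```

A model-soup setup would instead use the same step (e.g. the final one) for every RM, while all checkpoints still share one pretrained backbone.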

Fine-Tuning Multiple Reward Models

The process begins with a pre-trained Large Language Model (LLM) $\theta_{\mathrm{pt}}$. From this model, a Supervised Fine-Tuning (SFT) run produces a series of checkpoints $\{\theta^{\mathrm{sft}}_i\}$, each collected at a different SFT training step. These checkpoints are then used as initializations for fine-tuning multiple Reward Models (RMs) $\{\phi_i\}$ on a preference dataset, adapting them to align better with human preferences. After fine-tuning, these RMs are combined by weight averaging, resulting in the final model $\phi_{\mathrm{WARM}}$.
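The weight-averaging step can be written compactly. The extra notation is ours: $M$ is the number of RMs, $h_i$ the hyperparameters of run $i$, $\mathcal{D}_{\mathrm{pref}}$ the preference dataset, and $\mathrm{FT}$ shorthand for preference fine-tuning; uniform weights $1/M$ are assumed.

```latex
\phi_i = \mathrm{FT}\!\left(\theta^{\mathrm{sft}}_i;\ \mathcal{D}_{\mathrm{pref}},\ h_i\right), \qquad i = 1, \dots, M

\phi_{\mathrm{WARM}} = \frac{1}{M} \sum_{i=1}^{M} \phi_i
```

Because all $\phi_i$ descend from the same pretrained weights $\theta_{\mathrm{pt}}$, they tend to stay linearly mode connected, which is what lets the averaged weights behave roughly like an ensemble of the $M$ RMs at the cost of a single forward pass.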

Analysis indicates that also averaging in earlier checkpoints from each run (a moving average along the training trajectory) hurts individual RM performance and offsets the benefits of diversity. Averaging only the final weights of each run performs better. In general, balancing diversity against individual accuracy remains an open research challenge.
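The difference between the two averaging schemes is easy to see on toy trajectories. Here a plain uniform average over all checkpoints stands in for the moving average (an assumption), and the 1-D weight trajectories are hypothetical. Including early checkpoints drags the merged weights back toward less-trained models.

```python
def mean(vectors):
    """Uniform element-wise average of weight vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]

def average_final_weights(runs):
    """Average only the last checkpoint of each run."""
    return mean([run[-1] for run in runs])

def average_with_trajectory(runs):
    """Alternative: also average every intermediate checkpoint of every run."""
    return mean([ckpt for run in runs for ckpt in run])

# Hypothetical 1-D weight trajectories for two runs (3 checkpoints each).
runs = [
    [[0.0], [0.6], [1.0]],
    [[0.0], [0.8], [1.2]],
]
print(average_final_weights(runs), average_with_trajectory(runs))
```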

Overall, model merging aligns well with the general ethos in the field to recycle existing resources effectively for enhanced reliability, efficiency and versatility. The simplicity of weight averaging solidifies its position as a leading candidate for assembling robust models from readily available building blocks.

Unlike traditional ensembling methods that average predictions, WARM keeps computational overhead minimal by maintaining just a single set of weights. Experiments on text summarization tasks demonstrate WARM's effectiveness:

  • For best-of-N sampling, WARM attains a 92.5% win rate against random selection according to human preference labels.
  • In RLHF, a WARM policy reaches a 79.4% win rate against a policy trained with a single RM after the same number of steps.
  • WARM continues to perform well even when a quarter of the human labels are corrupted.
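Best-of-N sampling uses the RM directly at generation time: the policy samples N candidates and the RM keeps the highest-scoring one, so a more reliable RM picks better outputs. The sketch below assumes a hypothetical toy reward that prefers roughly eight-word summaries; in practice `reward_fn` would be the WARM model's score.

```python
from typing import Callable, Sequence

def best_of_n(candidates: Sequence[str], reward_fn: Callable[[str], float]) -> str:
    """Return the candidate that the reward model scores highest."""
    return max(candidates, key=reward_fn)

def toy_reward(summary: str) -> float:
    """Hypothetical stand-in for a WARM RM: prefers ~8-word summaries."""
    return -abs(len(summary.split()) - 8)

samples = [
    "Profits up.",
    "Quarterly profits rose 12% on strong cloud demand.",
    "The company, in the quarter that just ended, reported profits that rose by twelve percent overall.",
]
print(best_of_n(samples, toy_reward))
```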

These results illustrate WARM's potential as a practical technique for developing real-world AI assistants that behave reliably. By smoothing out inconsistencies in human feedback, WARM policies can remain robustly aligned with human values even as they continue learning from new experiences.

The Bigger Picture

WARM sits at the intersection of two key trends in AI alignment research. First is the study of out-of-distribution (OOD) generalization, which aims to enhance model performance on new data that differs from the training distribution. Second is research on algorithmic robustness, focusing on reliability despite small input perturbations or noise.

By drawing connections between these fields around the notion of learned invariances, WARM moves us toward more rigorously grounded techniques for value alignment. The insights from WARM could generalize even beyond RLHF, providing lessons for wider machine learning systems that interact with the open world.

Of course, reward modeling is just one piece of the alignment puzzle. We still need progress on other challenges like reward specification, scalable oversight, and safe exploration. Combined with complementary techniques, WARM could accelerate the development of AI that sustainably promotes human prosperity. By collectively elucidating the principles that underlie robust alignment, researchers are charting the route to beneficial, ethical AI.

I have spent the past five years immersing myself in the fascinating world of Machine Learning and Deep Learning. My passion and expertise have led me to contribute to over 50 diverse software engineering projects, with a particular focus on AI/ML. My ongoing curiosity has also drawn me toward Natural Language Processing, a field I am eager to explore further.