Center rewards coefficient (#3124)
* feat: add center_rewards_coefficient for reward modeling - Add center_rewards_coefficient parameter to Pydantic schema with paper reference - Pass parameter through base builder and causal builder to training args - Add documentation section with usage examples and theoretical background - Enable parameter in reward modeling example configs with recommended value - Enables reward centering for improved training stability in RLHF workflows Implements auxiliary loss from Eisenstein et al. 2023 (https://huggingface.co/papers/2312.09244) to incentivize mean-zero reward outputs without post-training normalization. * Update description * test: add unit tests for center_rewards_coefficient integration * Update src/axolotl/core/builders/base.py Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * Update docs/reward_modelling.qmd Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * Update docs/reward_modelling.qmd Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * reference to TRL documentation. * add new reward model configuration for qwen3 with comprehensive parameters * Verified center_rewards_coefficient is correctly passed through the trainer builder to training arguments. * Refactor reward modeling documentation to consolidate information on center_rewards_coefficient * Remove unit tests for center_rewards_coefficient integration as part of codebase cleanup. * linting * nit * Apply suggestions from code review Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * lint --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
This commit is contained in:
@@ -11,6 +11,7 @@ We support the reward modelling techniques supported by `trl`.
|
||||
### (Outcome) Reward Models
|
||||
|
||||
Outcome reward models are trained using data which contains preference annotations for an entire interaction between the user and model (e.g. rather than per-turn or per-step).
|
||||
For improved training stability, you can use the `center_rewards_coefficient` parameter to encourage mean-zero reward outputs ([see TRL docs](https://huggingface.co/docs/trl/v0.10.1/en/reward_trainer#centering-rewards)).
|
||||
|
||||
```yaml
|
||||
base_model: google/gemma-2-2b
|
||||
|
||||
44
examples/qwen3/reward-model.yaml
Normal file
44
examples/qwen3/reward-model.yaml
Normal file
@@ -0,0 +1,44 @@
|
||||
base_model: Skywork/Skywork-Reward-V2-Qwen3-8B
|
||||
model_type: AutoModelForSequenceClassification
|
||||
num_labels: 1
|
||||
|
||||
reward_model: true
|
||||
center_rewards_coefficient: 0.01 # Incentivize mean-zero rewards for improved stability
|
||||
chat_template: qwen3
|
||||
datasets:
|
||||
- path: argilla/distilabel-intel-orca-dpo-pairs
|
||||
type: bradley_terry.chat_template
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: false
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
deepspeed: deepspeed_configs/zero1.json
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
eval_batch_size: 1
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: linear
|
||||
learning_rate: 0.00002
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
warmup_ratio: 0.1
|
||||
logging_steps: 1
|
||||
weight_decay: 0.01
|
||||
@@ -7,10 +7,7 @@ from pathlib import Path
|
||||
from typing import Type, Union
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
DataCollatorWithFlattening,
|
||||
EarlyStoppingCallback,
|
||||
)
|
||||
from transformers import DataCollatorWithFlattening, EarlyStoppingCallback
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||
|
||||
from axolotl.core.builders.base import TrainerBuilderBase
|
||||
@@ -26,12 +23,12 @@ from axolotl.monkeypatch.relora import ReLoRACallback
|
||||
from axolotl.processing_strategies import get_processing_strategy
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
LossWatchDogCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
causal_lm_bench_eval_callback_factory,
|
||||
colab_inference_post_train_callback,
|
||||
log_prediction_callback_factory,
|
||||
LossWatchDogCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
)
|
||||
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
@@ -340,6 +337,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.reward_model:
|
||||
training_args_cls = AxolotlRewardConfig
|
||||
if self.cfg.center_rewards_coefficient is not None:
|
||||
training_arguments_kwargs["center_rewards_coefficient"] = (
|
||||
self.cfg.center_rewards_coefficient
|
||||
)
|
||||
elif self.cfg.process_reward_model:
|
||||
training_args_cls = AxolotlPRMConfig
|
||||
else:
|
||||
|
||||
@@ -138,6 +138,12 @@ class AxolotlInputConfig(
|
||||
"description": "Process reward modelling: `True` or `False`"
|
||||
},
|
||||
)
|
||||
center_rewards_coefficient: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`."
|
||||
},
|
||||
)
|
||||
num_labels: int | None = None
|
||||
# Whether to use weighting in DPO trainer.
|
||||
# If `None`, default is `False` in the trainer.
|
||||
|
||||
Reference in New Issue
Block a user