Center rewards coefficient (#3124)

* feat: add center_rewards_coefficient for reward modeling

- Add center_rewards_coefficient parameter to Pydantic schema with paper reference
- Pass parameter through base builder and causal builder to training args
- Add documentation section with usage examples and theoretical background
- Enable parameter in reward modeling example configs with recommended value
- Enable reward centering for improved training stability in RLHF workflows

Implements auxiliary loss from Eisenstein et al. 2023 (https://huggingface.co/papers/2312.09244)
to incentivize mean-zero reward outputs without post-training normalization.
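
In config terms, enabling this on top of an existing outcome reward-model setup is a one-line change; a minimal sketch using the field names from this PR's schema and example config (surrounding values are illustrative):

```yaml
model_type: AutoModelForSequenceClassification  # sequence-classification head used as the reward model
num_labels: 1
reward_model: true
center_rewards_coefficient: 0.01  # recommended value; omit to disable the auxiliary centering loss
```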

* Update description

* test: add unit tests for center_rewards_coefficient integration

* Update src/axolotl/core/builders/base.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update docs/reward_modelling.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update docs/reward_modelling.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Add reference to TRL documentation.

* Add new reward model configuration for Qwen3 with comprehensive parameters

* Verified center_rewards_coefficient is correctly passed through the trainer builder to training arguments.

* Refactor reward modeling documentation to consolidate information on center_rewards_coefficient

* Remove unit tests for center_rewards_coefficient integration as part of codebase cleanup.

* linting

* nit

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* lint

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
yardenhoch authored on 2025-09-03 23:22:37 +03:00; committed by GitHub
parent 48db520d92 · commit efa1da52d5
4 changed files with 58 additions and 6 deletions

View File

@@ -11,6 +11,7 @@ We support the reward modelling techniques supported by `trl`.
### (Outcome) Reward Models
Outcome reward models are trained on data containing preference annotations for an entire interaction between the user and the model (rather than, e.g., per-turn or per-step annotations).
For improved training stability, you can use the `center_rewards_coefficient` parameter to encourage mean-zero reward outputs ([see TRL docs](https://huggingface.co/docs/trl/v0.10.1/en/reward_trainer#centering-rewards)).
```yaml
base_model: google/gemma-2-2b
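# (the rest of this gemma-2 example is truncated in this diff view; per the
#  paragraph above, the new option sits alongside the reward-model settings,
#  e.g. the following lines -- values are illustrative)
# reward_model: true
# center_rewards_coefficient: 0.01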

View File

@@ -0,0 +1,44 @@
base_model: Skywork/Skywork-Reward-V2-Qwen3-8B
model_type: AutoModelForSequenceClassification
num_labels: 1
reward_model: true
center_rewards_coefficient: 0.01 # Incentivize mean-zero rewards for improved stability
chat_template: qwen3
datasets:
  - path: argilla/distilabel-intel-orca-dpo-pairs
    type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 8192
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
deepspeed: deepspeed_configs/zero1.json
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
eval_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: linear
learning_rate: 0.00002
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
warmup_ratio: 0.1
logging_steps: 1
weight_decay: 0.01

View File

@@ -7,10 +7,7 @@ from pathlib import Path
from typing import Type, Union
import transformers
from transformers import DataCollatorWithFlattening, EarlyStoppingCallback
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.builders.base import TrainerBuilderBase
@@ -26,12 +23,12 @@ from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
    LossWatchDogCallback,
    SaveBetterTransformerModelCallback,
    bench_eval_callback_factory,
    causal_lm_bench_eval_callback_factory,
    colab_inference_post_train_callback,
    log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.qat import QATCallback
@@ -340,6 +337,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        if self.cfg.reward_model:
            training_args_cls = AxolotlRewardConfig
            if self.cfg.center_rewards_coefficient is not None:
                training_arguments_kwargs["center_rewards_coefficient"] = (
                    self.cfg.center_rewards_coefficient
                )
        elif self.cfg.process_reward_model:
            training_args_cls = AxolotlPRMConfig
        else:

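For context on what the coefficient does downstream of this builder code: TRL's `RewardTrainer` uses it to scale an auxiliary penalty on the squared sum of chosen and rejected rewards (Eq. 2 of the Eisenstein et al. paper). Below is a minimal sketch of that loss; it is not code from this PR and is simplified relative to TRL's actual internals:

```python
import torch
import torch.nn.functional as F


def reward_loss_with_centering(
    rewards_chosen: torch.Tensor,
    rewards_rejected: torch.Tensor,
    center_rewards_coefficient: float | None = 0.01,
) -> torch.Tensor:
    """Pairwise Bradley-Terry reward loss plus the optional centering penalty."""
    # Standard preference loss: chosen responses should score above rejected ones.
    loss = -F.logsigmoid(rewards_chosen - rewards_rejected).mean()
    if center_rewards_coefficient is not None:
        # Auxiliary term: push the mean of (chosen + rejected) rewards toward zero,
        # so reward magnitudes stay centered without post-hoc normalization.
        loss = loss + center_rewards_coefficient * torch.mean(
            (rewards_chosen + rewards_rejected) ** 2
        )
    return loss


# Toy example: two preference pairs whose raw rewards sit well above zero.
chosen = torch.tensor([2.3, 1.8])
rejected = torch.tensor([1.9, 1.1])
print(reward_loss_with_centering(chosen, rejected))
```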
View File

@@ -138,6 +138,12 @@ class AxolotlInputConfig(
"description": "Process reward modelling: `True` or `False`"
},
)
center_rewards_coefficient: float | None = Field(
default=None,
json_schema_extra={
"description": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`."
},
)
num_labels: int | None = None
# Whether to use weighting in DPO trainer.
# If `None`, default is `False` in the trainer.