Center rewards coefficient (#3124)

* feat: add center_rewards_coefficient for reward modeling

- Add center_rewards_coefficient parameter to Pydantic schema with paper reference
- Pass parameter through base builder and causal builder to training args
- Add documentation section with usage examples and theoretical background
- Enable parameter in reward modeling example configs with recommended value
- Enable reward centering for improved training stability in RLHF workflows

Implements auxiliary loss from Eisenstein et al. 2023 (https://huggingface.co/papers/2312.09244)
to incentivize mean-zero reward outputs without post-training normalization.
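For context, the auxiliary loss amounts to a penalty on the squared sum of each chosen/rejected reward pair, added on top of the usual pairwise preference loss. A minimal pure-Python sketch (illustrative only; the real implementation lives in TRL's RewardTrainer and operates on torch tensors, and the function name here is made up):

```python
import math

def centered_reward_loss(rewards_chosen, rewards_rejected, coefficient=0.01):
    """Pairwise reward loss plus the auxiliary centering term.

    Base term: -log(sigmoid(r_chosen - r_rejected)), the standard
    Bradley-Terry preference loss. The centering term penalizes the
    squared sum of each reward pair, pushing the model toward
    mean-zero reward outputs (Eisenstein et al. 2023, Eq. 2).
    """
    n = len(rewards_chosen)
    base = sum(
        math.log(1.0 + math.exp(-(rc - rr)))  # == -log(sigmoid(rc - rr))
        for rc, rr in zip(rewards_chosen, rewards_rejected)
    ) / n
    centering = sum(
        (rc + rr) ** 2 for rc, rr in zip(rewards_chosen, rewards_rejected)
    ) / n
    return base + coefficient * centering
```

Note that when a reward pair already sums to zero, the centering term vanishes and only the preference loss remains.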

* Update description

* test: add unit tests for center_rewards_coefficient integration

* Update src/axolotl/core/builders/base.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update docs/reward_modelling.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update docs/reward_modelling.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Add reference to TRL documentation

* add new reward model configuration for qwen3 with comprehensive parameters

* Verified center_rewards_coefficient is correctly passed through the trainer builder to training arguments.

* Refactor reward modeling documentation to consolidate information on center_rewards_coefficient

* Remove unit tests for center_rewards_coefficient integration as part of codebase cleanup.

* linting

* nit

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* lint

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
Author: yardenhoch
Date: 2025-09-03 23:22:37 +03:00
Committed by: GitHub
Parent: 48db520d92
Commit: efa1da52d5
4 changed files with 58 additions and 6 deletions


@@ -7,10 +7,7 @@ from pathlib import Path
 from typing import Type, Union
 import transformers
-from transformers import (
-    DataCollatorWithFlattening,
-    EarlyStoppingCallback,
-)
+from transformers import DataCollatorWithFlattening, EarlyStoppingCallback
 from trl.trainer.utils import RewardDataCollatorWithPadding
 from axolotl.core.builders.base import TrainerBuilderBase
@@ -26,12 +23,12 @@ from axolotl.monkeypatch.relora import ReLoRACallback
 from axolotl.processing_strategies import get_processing_strategy
 from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.callbacks import (
+    LossWatchDogCallback,
+    SaveBetterTransformerModelCallback,
     bench_eval_callback_factory,
     causal_lm_bench_eval_callback_factory,
     colab_inference_post_train_callback,
     log_prediction_callback_factory,
-    LossWatchDogCallback,
-    SaveBetterTransformerModelCallback,
 )
 from axolotl.utils.callbacks.lisa import lisa_callback_factory
 from axolotl.utils.callbacks.qat import QATCallback
@@ -340,6 +337,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.reward_model:
             training_args_cls = AxolotlRewardConfig
+            if self.cfg.center_rewards_coefficient is not None:
+                training_arguments_kwargs["center_rewards_coefficient"] = (
+                    self.cfg.center_rewards_coefficient
+                )
         elif self.cfg.process_reward_model:
             training_args_cls = AxolotlPRMConfig
         else:


@@ -138,6 +138,12 @@ class AxolotlInputConfig(
             "description": "Process reward modelling: `True` or `False`"
         },
     )
+    center_rewards_coefficient: float | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`."
+        },
+    )
     num_labels: int | None = None
     # Whether to use weighting in DPO trainer.
     # If `None`, default is `False` in the trainer.
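With the schema field in place, a config would enable the option like this. A hypothetical minimal fragment: `reward_model`, `center_rewards_coefficient`, and the recommended `0.01` come from this change; `base_model` is a placeholder:

```yaml
# Hypothetical reward-modelling config fragment
base_model: Qwen/Qwen3-0.6B  # placeholder model, not from this diff
reward_model: true
center_rewards_coefficient: 0.01  # recommended value per the schema description
```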