Center rewards coefficient (#3124)
* feat: add center_rewards_coefficient for reward modeling - Add center_rewards_coefficient parameter to Pydantic schema with paper reference - Pass parameter through base builder and causal builder to training args - Add documentation section with usage examples and theoretical background - Enable parameter in reward modeling example configs with recommended value - Enables reward centering for improved training stability in RLHF workflows Implements auxiliary loss from Eisenstein et al. 2023 (https://huggingface.co/papers/2312.09244) to incentivize mean-zero reward outputs without post-training normalization. * Update description * test: add unit tests for center_rewards_coefficient integration * Update src/axolotl/core/builders/base.py Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * Update docs/reward_modelling.qmd Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * Update docs/reward_modelling.qmd Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * reference to TRL documentation. * add new reward model configuration for qwen3 with comprehensive parameters * Verified center_rewards_coefficient is correctly passed through the trainer builder to training arguments. * Refactor reward modeling documentation to consolidate information on center_rewards_coefficient * Remove unit tests for center_rewards_coefficient integration as part of codebase cleanup. * linting * nit * Apply suggestions from code review Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * lint --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
This commit is contained in:
@@ -7,10 +7,7 @@ from pathlib import Path
|
||||
from typing import Type, Union
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
DataCollatorWithFlattening,
|
||||
EarlyStoppingCallback,
|
||||
)
|
||||
from transformers import DataCollatorWithFlattening, EarlyStoppingCallback
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||
|
||||
from axolotl.core.builders.base import TrainerBuilderBase
|
||||
@@ -26,12 +23,12 @@ from axolotl.monkeypatch.relora import ReLoRACallback
|
||||
from axolotl.processing_strategies import get_processing_strategy
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
LossWatchDogCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
causal_lm_bench_eval_callback_factory,
|
||||
colab_inference_post_train_callback,
|
||||
log_prediction_callback_factory,
|
||||
LossWatchDogCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
)
|
||||
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
@@ -340,6 +337,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.reward_model:
|
||||
training_args_cls = AxolotlRewardConfig
|
||||
if self.cfg.center_rewards_coefficient is not None:
|
||||
training_arguments_kwargs["center_rewards_coefficient"] = (
|
||||
self.cfg.center_rewards_coefficient
|
||||
)
|
||||
elif self.cfg.process_reward_model:
|
||||
training_args_cls = AxolotlPRMConfig
|
||||
else:
|
||||
|
||||
@@ -138,6 +138,12 @@ class AxolotlInputConfig(
|
||||
"description": "Process reward modelling: `True` or `False`"
|
||||
},
|
||||
)
|
||||
center_rewards_coefficient: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`."
|
||||
},
|
||||
)
|
||||
num_labels: int | None = None
|
||||
# Whether to use weighting in DPO trainer.
|
||||
# If `None`, default is `False` in the trainer.
|
||||
|
||||
Reference in New Issue
Block a user