diff --git a/docs/reward_modelling.qmd b/docs/reward_modelling.qmd index 386dc1f57..b5cf3010d 100644 --- a/docs/reward_modelling.qmd +++ b/docs/reward_modelling.qmd @@ -11,6 +11,7 @@ We support the reward modelling techniques supported by `trl`. ### (Outcome) Reward Models Outcome reward models are trained using data which contains preference annotations for an entire interaction between the user and model (e.g. rather than per-turn or per-step). +For improved training stability, you can use the `center_rewards_coefficient` parameter to encourage mean-zero reward outputs ([see TRL docs](https://huggingface.co/docs/trl/v0.10.1/en/reward_trainer#centering-rewards)). ```yaml base_model: google/gemma-2-2b diff --git a/examples/qwen3/reward-model.yaml b/examples/qwen3/reward-model.yaml new file mode 100644 index 000000000..43c62ecc4 --- /dev/null +++ b/examples/qwen3/reward-model.yaml @@ -0,0 +1,44 @@ +base_model: Skywork/Skywork-Reward-V2-Qwen3-8B +model_type: AutoModelForSequenceClassification +num_labels: 1 + +reward_model: true +center_rewards_coefficient: 0.01 # Incentivize mean-zero rewards for improved stability +chat_template: qwen3 +datasets: + - path: argilla/distilabel-intel-orca-dpo-pairs + type: bradley_terry.chat_template + +val_set_size: 0.0 +output_dir: ./outputs/out + +sequence_len: 8192 +sample_packing: false +eval_sample_packing: false +pad_to_sequence_len: true + +deepspeed: deepspeed_configs/zero1.json + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +eval_batch_size: 1 +num_epochs: 3 +optimizer: adamw_bnb_8bit +lr_scheduler: linear +learning_rate: 0.00002 + +bf16: true +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +warmup_ratio: 0.1 +logging_steps: 1 +weight_decay: 0.01 diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index e5bc68c39..057d0ab5c 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -7,10 +7,7 @@ from pathlib import Path from typing import Type, Union import transformers -from transformers import ( - DataCollatorWithFlattening, - EarlyStoppingCallback, -) +from transformers import DataCollatorWithFlattening, EarlyStoppingCallback from trl.trainer.utils import RewardDataCollatorWithPadding from axolotl.core.builders.base import TrainerBuilderBase @@ -26,12 +23,12 @@ from axolotl.monkeypatch.relora import ReLoRACallback from axolotl.processing_strategies import get_processing_strategy from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.callbacks import ( - LossWatchDogCallback, - SaveBetterTransformerModelCallback, bench_eval_callback_factory, causal_lm_bench_eval_callback_factory, colab_inference_post_train_callback, log_prediction_callback_factory, + LossWatchDogCallback, + SaveBetterTransformerModelCallback, ) from axolotl.utils.callbacks.lisa import lisa_callback_factory from axolotl.utils.callbacks.qat import QATCallback @@ -340,6 +337,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if self.cfg.reward_model: training_args_cls = AxolotlRewardConfig + if self.cfg.center_rewards_coefficient is not None: + training_arguments_kwargs["center_rewards_coefficient"] = ( + self.cfg.center_rewards_coefficient + ) elif self.cfg.process_reward_model: training_args_cls = AxolotlPRMConfig else: diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 1d2ddf4ae..32d7b68e7 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -138,6 +138,12 @@ class AxolotlInputConfig( "description": "Process reward modelling: `True` or `False`" }, ) + center_rewards_coefficient: float | None = Field( + default=None, + json_schema_extra={ + "description": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`." + }, + ) num_labels: int | None = None # Whether to use weighting in DPO trainer. # If `None`, default is `False` in the trainer.