Set reward-model defaults and keep the model config consistent with them.

* src/axolotl/core/builders/causal.py: when cfg.reward_model is set and the
  loaded model's config.num_labels is not 1, set it to 1 before the trainer
  is constructed.
* src/axolotl/utils/schemas/validation.py: add a "before" validator that
  fills in num_labels and model_type for reward_model (1,
  AutoModelForSequenceClassification) and process_reward_model (2,
  AutoModelForTokenClassification) when they were not provided.
* tests: re-enable the "rm_cfg" builder case and cover both sets of defaults.

diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index c238cbbc3..f26ef8969 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -421,6 +421,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             trainer_kwargs["dataset_tags"] = [
                 d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
             ]
+        # TRL's RewardTrainer validates num_labels=1 on pre-loaded models; ensure the
+        # config reflects this regardless of how the model was instantiated.
+        if (
+            self.cfg.reward_model
+            and getattr(self.model.config, "num_labels", None) != 1
+        ):
+            self.model.config.num_labels = 1
         trainer = trainer_cls(
             model=self.model,
             train_dataset=self.train_dataset,
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 50b30cd26..5e6657a78 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -253,6 +253,23 @@ class TrainingValidationMixin:
             data["pad_to_sequence_len"] = True
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def set_reward_model_defaults(cls, data):
+        if data.get("reward_model"):
+            if data.get("num_labels") is None:
+                data["num_labels"] = 1
+            if not (data.get("type_of_model") or data.get("model_type")):
+                data["model_type"] = "AutoModelForSequenceClassification"
+
+        if data.get("process_reward_model"):
+            if data.get("num_labels") is None:
+                data["num_labels"] = 2
+            if not (data.get("type_of_model") or data.get("model_type")):
+                data["model_type"] = "AutoModelForTokenClassification"
+
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_gas_bsz(cls, data):
diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py
index ea3c4e6c4..a241e8549 100644
--- a/tests/core/test_builders.py
+++ b/tests/core/test_builders.py
@@ -536,7 +536,7 @@ class TestHFCausalTrainerBuilder:
         "cfg_string",
         [
             "sft_cfg",
-            # "rm_cfg",  # TODO fix for num_labels = 2 vs 1
+            "rm_cfg",
             "prm_cfg",
         ],
     )
diff --git a/tests/patched/test_validation.py b/tests/patched/test_validation.py
index 21299ed98..d22927940 100644
--- a/tests/patched/test_validation.py
+++ b/tests/patched/test_validation.py
@@ -277,6 +277,34 @@ class TestValidation(BaseValidation):
         new_cfg = validate_config(cfg)
         assert new_cfg.type_of_model == "AutoModelForCausalLM"
 
+    def test_reward_model_defaults(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "reward_model": True,
+                }
+            )
+            | minimal_cfg
+        )
+
+        new_cfg = validate_config(cfg)
+        assert new_cfg.num_labels == 1
+        assert new_cfg.type_of_model == "AutoModelForSequenceClassification"
+
+    def test_process_reward_model_defaults(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "process_reward_model": True,
+                }
+            )
+            | minimal_cfg
+        )
+
+        new_cfg = validate_config(cfg)
+        assert new_cfg.num_labels == 2
+        assert new_cfg.type_of_model == "AutoModelForTokenClassification"
+
     def test_model_revision_remap(self, minimal_cfg):
         cfg = (
             DictDefault(