fix num_labels= 1 test fail (#3493) [skip ci]

* trl_num_lables=1

* casual num_lables=1,rwd model

* lint
This commit is contained in:
VED
2026-03-20 14:42:23 +05:30
committed by GitHub
parent 1fc86d5295
commit 7920fe74ec
4 changed files with 53 additions and 1 deletions

View File

@@ -421,6 +421,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs["dataset_tags"] = [ trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir() d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
] ]
# TRL's RewardTrainer validates num_labels=1 on pre-loaded models; ensure the
# config reflects this regardless of how the model was instantiated.
if (
self.cfg.reward_model
and getattr(self.model.config, "num_labels", None) != 1
):
self.model.config.num_labels = 1
trainer = trainer_cls( trainer = trainer_cls(
model=self.model, model=self.model,
train_dataset=self.train_dataset, train_dataset=self.train_dataset,

View File

@@ -253,6 +253,23 @@ class TrainingValidationMixin:
data["pad_to_sequence_len"] = True data["pad_to_sequence_len"] = True
return data return data
@model_validator(mode="before")
@classmethod
def set_reward_model_defaults(cls, data):
if data.get("reward_model"):
if data.get("num_labels") is None:
data["num_labels"] = 1
if not (data.get("type_of_model") or data.get("model_type")):
data["model_type"] = "AutoModelForSequenceClassification"
if data.get("process_reward_model"):
if data.get("num_labels") is None:
data["num_labels"] = 2
if not (data.get("type_of_model") or data.get("model_type")):
data["model_type"] = "AutoModelForTokenClassification"
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_gas_bsz(cls, data): def check_gas_bsz(cls, data):

View File

@@ -536,7 +536,7 @@ class TestHFCausalTrainerBuilder:
"cfg_string", "cfg_string",
[ [
"sft_cfg", "sft_cfg",
# "rm_cfg", # TODO fix for num_labels = 2 vs 1 "rm_cfg",
"prm_cfg", "prm_cfg",
], ],
) )

View File

@@ -277,6 +277,34 @@ class TestValidation(BaseValidation):
new_cfg = validate_config(cfg) new_cfg = validate_config(cfg)
assert new_cfg.type_of_model == "AutoModelForCausalLM" assert new_cfg.type_of_model == "AutoModelForCausalLM"
def test_reward_model_defaults(self, minimal_cfg):
cfg = (
DictDefault(
{
"reward_model": True,
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
assert new_cfg.num_labels == 1
assert new_cfg.type_of_model == "AutoModelForSequenceClassification"
def test_process_reward_model_defaults(self, minimal_cfg):
cfg = (
DictDefault(
{
"process_reward_model": True,
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
assert new_cfg.num_labels == 2
assert new_cfg.type_of_model == "AutoModelForTokenClassification"
def test_model_revision_remap(self, minimal_cfg): def test_model_revision_remap(self, minimal_cfg):
cfg = ( cfg = (
DictDefault( DictDefault(