fix num_labels= 1 test fail (#3493) [skip ci]

* trl_num_lables=1

* casual num_lables=1,rwd model

* lint
This commit is contained in:
VED
2026-03-20 14:42:23 +05:30
committed by GitHub
parent 1fc86d5295
commit 7920fe74ec
4 changed files with 53 additions and 1 deletions

View File

@@ -536,7 +536,7 @@ class TestHFCausalTrainerBuilder:
"cfg_string",
[
"sft_cfg",
# "rm_cfg", # TODO fix for num_labels = 2 vs 1
"rm_cfg",
"prm_cfg",
],
)

View File

@@ -277,6 +277,34 @@ class TestValidation(BaseValidation):
new_cfg = validate_config(cfg)
assert new_cfg.type_of_model == "AutoModelForCausalLM"
def test_reward_model_defaults(self, minimal_cfg):
cfg = (
DictDefault(
{
"reward_model": True,
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
assert new_cfg.num_labels == 1
assert new_cfg.type_of_model == "AutoModelForSequenceClassification"
def test_process_reward_model_defaults(self, minimal_cfg):
cfg = (
DictDefault(
{
"process_reward_model": True,
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
assert new_cfg.num_labels == 2
assert new_cfg.type_of_model == "AutoModelForTokenClassification"
def test_model_revision_remap(self, minimal_cfg):
cfg = (
DictDefault(