fix num_labels= 1 test fail (#3493) [skip ci]
* trl_num_lables=1 * casual num_lables=1,rwd model * lint
This commit is contained in:
@@ -421,6 +421,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_kwargs["dataset_tags"] = [
|
trainer_kwargs["dataset_tags"] = [
|
||||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
]
|
]
|
||||||
|
# TRL's RewardTrainer validates num_labels=1 on pre-loaded models; ensure the
|
||||||
|
# config reflects this regardless of how the model was instantiated.
|
||||||
|
if (
|
||||||
|
self.cfg.reward_model
|
||||||
|
and getattr(self.model.config, "num_labels", None) != 1
|
||||||
|
):
|
||||||
|
self.model.config.num_labels = 1
|
||||||
trainer = trainer_cls(
|
trainer = trainer_cls(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
train_dataset=self.train_dataset,
|
train_dataset=self.train_dataset,
|
||||||
|
|||||||
@@ -253,6 +253,23 @@ class TrainingValidationMixin:
|
|||||||
data["pad_to_sequence_len"] = True
|
data["pad_to_sequence_len"] = True
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def set_reward_model_defaults(cls, data):
|
||||||
|
if data.get("reward_model"):
|
||||||
|
if data.get("num_labels") is None:
|
||||||
|
data["num_labels"] = 1
|
||||||
|
if not (data.get("type_of_model") or data.get("model_type")):
|
||||||
|
data["model_type"] = "AutoModelForSequenceClassification"
|
||||||
|
|
||||||
|
if data.get("process_reward_model"):
|
||||||
|
if data.get("num_labels") is None:
|
||||||
|
data["num_labels"] = 2
|
||||||
|
if not (data.get("type_of_model") or data.get("model_type")):
|
||||||
|
data["model_type"] = "AutoModelForTokenClassification"
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_gas_bsz(cls, data):
|
def check_gas_bsz(cls, data):
|
||||||
|
|||||||
@@ -536,7 +536,7 @@ class TestHFCausalTrainerBuilder:
|
|||||||
"cfg_string",
|
"cfg_string",
|
||||||
[
|
[
|
||||||
"sft_cfg",
|
"sft_cfg",
|
||||||
# "rm_cfg", # TODO fix for num_labels = 2 vs 1
|
"rm_cfg",
|
||||||
"prm_cfg",
|
"prm_cfg",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -277,6 +277,34 @@ class TestValidation(BaseValidation):
|
|||||||
new_cfg = validate_config(cfg)
|
new_cfg = validate_config(cfg)
|
||||||
assert new_cfg.type_of_model == "AutoModelForCausalLM"
|
assert new_cfg.type_of_model == "AutoModelForCausalLM"
|
||||||
|
|
||||||
|
def test_reward_model_defaults(self, minimal_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
{
|
||||||
|
"reward_model": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
assert new_cfg.num_labels == 1
|
||||||
|
assert new_cfg.type_of_model == "AutoModelForSequenceClassification"
|
||||||
|
|
||||||
|
def test_process_reward_model_defaults(self, minimal_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
{
|
||||||
|
"process_reward_model": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
assert new_cfg.num_labels == 2
|
||||||
|
assert new_cfg.type_of_model == "AutoModelForTokenClassification"
|
||||||
|
|
||||||
def test_model_revision_remap(self, minimal_cfg):
|
def test_model_revision_remap(self, minimal_cfg):
|
||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
|
|||||||
Reference in New Issue
Block a user