fix num_labels= 1 test fail (#3493) [skip ci]
* trl_num_lables=1 * casual num_lables=1,rwd model * lint
This commit is contained in:
@@ -277,6 +277,34 @@ class TestValidation(BaseValidation):
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.type_of_model == "AutoModelForCausalLM"
|
||||
|
||||
def test_reward_model_defaults(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"reward_model": True,
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.num_labels == 1
|
||||
assert new_cfg.type_of_model == "AutoModelForSequenceClassification"
|
||||
|
||||
def test_process_reward_model_defaults(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"process_reward_model": True,
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.num_labels == 2
|
||||
assert new_cfg.type_of_model == "AutoModelForTokenClassification"
|
||||
|
||||
def test_model_revision_remap(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
|
||||
Reference in New Issue
Block a user