This commit is contained in:
sunny
2024-11-05 12:46:05 -05:00
parent 41d10278bf
commit 39af2a41a5
2 changed files with 5 additions and 3 deletions

View File

@@ -588,7 +588,9 @@ class AxolotlInputConfig(
rl: Optional[RLType] = None
reward_model: Optional[bool] = None
dpo_use_weighting: Optional[bool] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
dpo_use_weighting: Optional[
bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore

View File

@@ -114,7 +114,7 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@with_temp_dir
def test_dpo_use_weighting(self, temp_dir):
# pylint: disable=duplicate-code
@@ -159,7 +159,7 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@pytest.mark.skip("kto_pair no longer supported in trl")
@with_temp_dir
def test_kto_pair_lora(self, temp_dir):