diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 2113fcd89..64fec67e0 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -588,7 +588,9 @@ class AxolotlInputConfig( rl: Optional[RLType] = None reward_model: Optional[bool] = None - dpo_use_weighting: Optional[bool] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer. + dpo_use_weighting: Optional[ + bool + ] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer. datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index d909c5643..4a705922f 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -114,7 +114,7 @@ class TestDPOLlamaLora(unittest.TestCase): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() - + @with_temp_dir def test_dpo_use_weighting(self, temp_dir): # pylint: disable=duplicate-code @@ -159,7 +159,7 @@ class TestDPOLlamaLora(unittest.TestCase): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() - + @pytest.mark.skip("kto_pair no longer supported in trl") @with_temp_dir def test_kto_pair_lora(self, temp_dir):