diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index 38c4e787b..156949fd8 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -115,8 +115,8 @@ class TestDPOLlamaLora(unittest.TestCase): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() - '''@with_temp_dir - def test_dpo_use_weighting(self, temp_dir): + @with_temp_dir + def test_dpo_nll_use_weighting(self, temp_dir): # pylint: disable=duplicate-code cfg = DictDefault( { @@ -158,7 +158,7 @@ class TestDPOLlamaLora(unittest.TestCase): dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()''' + assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() @pytest.mark.skip("kto_pair no longer supported in trl") @with_temp_dir