diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index 722e1702b..d909c5643 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -26,51 +26,6 @@ class TestDPOLlamaLora(unittest.TestCase): Test case for DPO Llama models using LoRA """ - @with_temp_dir - def test_dpo_nll_lora(self, temp_dir): - # pylint: disable=duplicate-code - cfg = DictDefault( - { - "base_model": "JackFram/llama-68m", - "tokenizer_type": "LlamaTokenizer", - "sequence_len": 1024, - "load_in_8bit": True, - "adapter": "lora", - "lora_r": 64, - "lora_alpha": 32, - "lora_dropout": 0.1, - "lora_target_linear": True, - "special_tokens": {}, - "rl": "dpo", - "rpo_alpha": 0.5, - "datasets": [ - { - "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", - "type": "chatml.ultra", - "split": "train", - }, - ], - "num_epochs": 1, - "micro_batch_size": 4, - "gradient_accumulation_steps": 1, - "output_dir": temp_dir, - "learning_rate": 0.00001, - "optimizer": "paged_adamw_8bit", - "lr_scheduler": "cosine", - "max_steps": 20, - "save_steps": 10, - "warmup_steps": 5, - "gradient_checkpointing": True, - "gradient_checkpointing_kwargs": {"use_reentrant": True}, - } - ) - normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) - - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() - @with_temp_dir def test_dpo_lora(self, temp_dir): # pylint: disable=duplicate-code @@ -115,6 +70,51 @@ class TestDPOLlamaLora(unittest.TestCase): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() + @with_temp_dir + def test_dpo_nll_lora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 64, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "special_tokens": {}, + "rl": "dpo", + "rpo_alpha": 0.5, + "datasets": [ + { + "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", + "type": "chatml.ultra", + "split": "train", + }, + ], + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "paged_adamw_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "warmup_steps": 5, + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": {"use_reentrant": True}, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() + @with_temp_dir def test_dpo_use_weighting(self, temp_dir): # pylint: disable=duplicate-code