From 1f151c0d52d2d4c78c5e1b1a4ff4fb64cba1f45d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 3 Jun 2024 12:50:44 -0400 Subject: [PATCH] re-enable DPO for tests in modal ci (#1374) * re-enable DPO for tests in modal ci * workaround for training args * don't mixin AxolotlTrainingArguments * fix mixin order so MRO doesn't result in TypeError: non-default argument follows default argument error * use smaller datasets for dpo tests --- .../prompt_strategies/orpo/chat_template.py | 16 +++++++++++----- tests/e2e/test_dpo.py | 16 ++++++++-------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/axolotl/prompt_strategies/orpo/chat_template.py b/src/axolotl/prompt_strategies/orpo/chat_template.py index a89dee157..bba694856 100644 --- a/src/axolotl/prompt_strategies/orpo/chat_template.py +++ b/src/axolotl/prompt_strategies/orpo/chat_template.py @@ -56,7 +56,9 @@ class ORPODatasetParsingStrategy: messages: List[Message] = [] if system := prompt.get("system", None): messages.append(Message(role="system", content=system, label=False)) - messages.append(Message(role="user", content=prompt["prompt"], label=False)) + messages.append( + Message(role="user", content=prompt["chosen"][0]["content"], label=False) + ) messages.append( Message( role="assistant", content=prompt["chosen"][1]["content"], label=True @@ -70,7 +72,9 @@ class ORPODatasetParsingStrategy: messages: List[Message] = [] if system := prompt.get("system", None): messages.append(Message(role="system", content=system, label=False)) - messages.append(Message(role="user", content=prompt["prompt"], label=False)) + messages.append( + Message(role="user", content=prompt["rejected"][0]["content"], label=False) + ) messages.append( Message( role="assistant", content=prompt["rejected"][1]["content"], label=True @@ -152,8 +156,8 @@ class ORPOTokenizingStrategy(PromptTokenizingStrategy): def tokenize_prompt(self, prompt): # pass the rejected prompt/row to the Prompter to get the formatted prompt prompt_len = 0 - rejected_message_list = self.dataset_parser.get_rejected_conversation_thread( - prompt + rejected_message_list: MessageList = ( + self.dataset_parser.get_rejected_conversation_thread(prompt) ) input_ids = [] labels = [] @@ -174,7 +178,9 @@ class ORPOTokenizingStrategy(PromptTokenizingStrategy): rejected_input_ids = input_ids rejected_labels = labels # pass the chosen prompt/row to the Prompter to get the formatted prompt - chosen_message_list = self.dataset_parser.get_chosen_conversation_thread(prompt) + chosen_message_list: MessageList = ( + self.dataset_parser.get_chosen_conversation_thread(prompt) + ) input_ids = [] labels = [] for _, (part, label) in enumerate( diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index ddd63d827..5d2522bdf 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -21,7 +21,6 @@ LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" -@pytest.mark.skip(reason="doesn't seem to work on modal") class TestDPOLlamaLora(unittest.TestCase): """ Test case for DPO Llama models using LoRA @@ -45,8 +44,8 @@ class TestDPOLlamaLora(unittest.TestCase): "rl": "dpo", "datasets": [ { - "path": "Intel/orca_dpo_pairs", - "type": "chatml.intel", + "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", + "type": "chatml.ultra", "split": "train", }, ], @@ -89,8 +88,8 @@ class TestDPOLlamaLora(unittest.TestCase): "rl": "kto_pair", "datasets": [ { - "path": "Intel/orca_dpo_pairs", - "type": "chatml.intel", + "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", + "type": "chatml.ultra", "split": "train", }, ], @@ -133,8 +132,8 @@ class TestDPOLlamaLora(unittest.TestCase): "rl": "ipo", "datasets": [ { - "path": "Intel/orca_dpo_pairs", - "type": "chatml.intel", + "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", + "type": "chatml.ultra", "split": "train", }, ], @@ -180,7 +179,7 @@ class TestDPOLlamaLora(unittest.TestCase): "chat_template": "chatml", "datasets": [ { - "path": "argilla/ultrafeedback-binarized-preferences-cleaned", + "path": "argilla/distilabel-capybara-dpo-7k-binarized", "type": "chat_template.argilla", "split": "train", }, @@ -206,6 +205,7 @@ class TestDPOLlamaLora(unittest.TestCase): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() + @pytest.mark.skip(reason="Fix the implementation") @with_temp_dir def test_kto_lora(self, temp_dir): # pylint: disable=duplicate-code