diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py index 60333b33b..85490349b 100644 --- a/src/axolotl/prompt_strategies/dpo/chat_template.py +++ b/src/axolotl/prompt_strategies/dpo/chat_template.py @@ -53,6 +53,7 @@ def default( "role": role_map[sample[field_rejected][field_message_role]], "content": sample[field_rejected][field_message_content], } + dummy_user_message = {"role": "user", "content": "dummy"} result = {} result["prompt"] = tokenizer.apply_chat_template( @@ -63,7 +64,7 @@ def default( ) result["chosen"] = tokenizer.apply_chat_template( - [chosen], + [dummy_user_message, chosen], add_generation_prompt=False, chat_template=chat_template_string, tokenize=False, @@ -72,7 +73,7 @@ def default( result["chosen"] = result["chosen"][chosen_strip_index:].rstrip() result["rejected"] = tokenizer.apply_chat_template( - [rejected], + [dummy_user_message, rejected], add_generation_prompt=False, chat_template=chat_template_string, tokenize=False, diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py index 534249c4f..740edc22f 100644 --- a/tests/prompt_strategies/test_dpo_chat_templates.py +++ b/tests/prompt_strategies/test_dpo_chat_templates.py @@ -93,6 +93,13 @@ def fixture_phi3_tokenizer(): return tokenizer +@pytest.fixture(name="gemma_tokenizer") +def fixture_gemma_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a") + + return tokenizer + + class TestAssistantDPOChatTemplateLlama3: """ Test class for assistant style datasets with llama-3 prompts using the chat_template strategy. 
@@ -106,7 +113,7 @@ class TestAssistantDPOChatTemplateLlama3: "chat_template": "llama3", "datasets": [ { - "chat_template": "llama3", + "type": "chat_template", } ], } @@ -131,7 +138,7 @@ class TestAssistantDPOChatTemplateLlama3: "chat_template": "llama3", "datasets": [ { - "chat_template": "llama3", + "type": "chat_template", "field_messages": "conversation", "field_chosen": "better", "field_rejected": "worse", @@ -173,7 +180,6 @@ class TestAssistantDPOChatTemplatePhi3: "datasets": [ { "type": "chat_template", - "chat_template": "tokenizer_default", } ], } @@ -190,5 +196,35 @@ class TestAssistantDPOChatTemplatePhi3: assert result["rejected"] == "party on<|end|>" +class TestAssistantDPOChatTemplateGemma: + """ + Test class for assistant style datasets with gemma prompts using the tokenizer's chat_template strategy. + """ + + def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset): + # pylint: disable=duplicate-code + transform_fn = default( + DictDefault( + { + "chat_template": "tokenizer_default", + "datasets": [ + { + "type": "chat_template", + } + ], + } + ) + ) + result = transform_fn(assistant_dataset[0], tokenizer=gemma_tokenizer) + assert result["prompt"] == ( + "<bos><start_of_turn>user\nhello<end_of_turn>\n" + + "<start_of_turn>model\nhello<end_of_turn>\n" + + "<start_of_turn>user\ngoodbye<end_of_turn>\n" + + "<start_of_turn>model\n" + ) + assert result["chosen"] == "goodbye<end_of_turn>" + assert result["rejected"] == "party on<end_of_turn>" + + if __name__ == "__main__": unittest.main()