diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py index cca48b1cf..534249c4f 100644 --- a/tests/prompt_strategies/test_dpo_chat_templates.py +++ b/tests/prompt_strategies/test_dpo_chat_templates.py @@ -86,6 +86,13 @@ def fixture_llama3_tokenizer(): return tokenizer +@pytest.fixture(name="phi3_tokenizer") +def fixture_phi3_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct") + + return tokenizer + + class TestAssistantDPOChatTemplateLlama3: """ Test class for assistant style datasets with llama-3 prompts using the chat_template strategy. @@ -152,5 +159,36 @@ class TestAssistantDPOChatTemplateLlama3: assert result["rejected"] == "party on<|eot_id|>" +class TestAssistantDPOChatTemplatePhi3: + """ + Test class for assistant style datasets with phi-3 prompts using the tokenizer's chat_template strategy. + """ + + def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset): + # pylint: disable=duplicate-code + transform_fn = default( + DictDefault( + { + "chat_template": "tokenizer_default", + "datasets": [ + { + "type": "chat_template", + "chat_template": "tokenizer_default", + } + ], + } + ) + ) + result = transform_fn(assistant_dataset[0], tokenizer=phi3_tokenizer) + assert result["prompt"] == ( + "<|user|>\nhello<|end|>\n" + + "<|assistant|>\nhello<|end|>\n" + + "<|user|>\ngoodbye<|end|>\n" + + "<|assistant|>\n" + ) + assert result["chosen"] == "goodbye<|end|>" + assert result["rejected"] == "party on<|end|>" + + if __name__ == "__main__": unittest.main()