feat: add test for levy's dpo case

This commit is contained in:
NanoCode012
2024-10-11 12:56:46 +07:00
parent ef942b6efc
commit dd87d8c438

View File

@@ -86,6 +86,13 @@ def fixture_llama3_tokenizer():
return tokenizer
@pytest.fixture(name="phi3_tokenizer")
def fixture_phi3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct")
return tokenizer
class TestAssistantDPOChatTemplateLlama3:
"""
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
@@ -152,5 +159,36 @@ class TestAssistantDPOChatTemplateLlama3:
assert result["rejected"] == "party on<|eot_id|>"
class TestAssistantDPOChatTemplatePhi3:
"""
Test class for assistant style datasets with phi-3 prompts using the tokenizer's chat_template strategy.
"""
def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
DictDefault(
{
"chat_template": "tokenizer_default",
"datasets": [
{
"type": "chat_template",
"chat_template": "tokenizer_default",
}
],
}
)
)
result = transform_fn(assistant_dataset[0], tokenizer=phi3_tokenizer)
assert result["prompt"] == (
"<|user|>\nhello<|end|>\n"
+ "<|assistant|>\nhello<|end|>\n"
+ "<|user|>\ngoodbye<|end|>\n"
+ "<|assistant|>\n"
)
assert result["chosen"] == "goodbye<|end|>"
assert result["rejected"] == "party on<|end|>"
if __name__ == "__main__":
unittest.main()