feat: add test for levy's DPO case
This commit is contained in:
@@ -86,6 +86,13 @@ def fixture_llama3_tokenizer():
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="phi3_tokenizer")
def fixture_phi3_tokenizer():
    """Return the tokenizer loaded for the Phi-3 medium 128k instruct model id."""
    model_id = "microsoft/Phi-3-medium-128k-instruct"
    return AutoTokenizer.from_pretrained(model_id)
|
||||||
|
|
||||||
|
|
||||||
class TestAssistantDPOChatTemplateLlama3:
|
class TestAssistantDPOChatTemplateLlama3:
|
||||||
"""
|
"""
|
||||||
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
||||||
@@ -152,5 +159,36 @@ class TestAssistantDPOChatTemplateLlama3:
|
|||||||
assert result["rejected"] == "party on<|eot_id|>"
|
assert result["rejected"] == "party on<|eot_id|>"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssistantDPOChatTemplatePhi3:
    """
    Tests for assistant style datasets with phi-3 prompts, rendered through the tokenizer's own chat_template.
    """

    def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
        # pylint: disable=duplicate-code
        # Both the top-level config and the dataset entry select the
        # "tokenizer_default" chat template.
        dataset_cfg = {
            "type": "chat_template",
            "chat_template": "tokenizer_default",
        }
        cfg = DictDefault(
            {
                "chat_template": "tokenizer_default",
                "datasets": [dataset_cfg],
            }
        )
        transform_fn = default(cfg)

        first_row = assistant_dataset[0]
        result = transform_fn(first_row, tokenizer=phi3_tokenizer)

        # The expected prompt ends with the assistant header, with no reply text after it.
        expected_prompt = "".join(
            [
                "<|user|>\nhello<|end|>\n",
                "<|assistant|>\nhello<|end|>\n",
                "<|user|>\ngoodbye<|end|>\n",
                "<|assistant|>\n",
            ]
        )
        assert result["prompt"] == expected_prompt
        assert result["chosen"] == "goodbye<|end|>"
        assert result["rejected"] == "party on<|end|>"
|
||||||
|
|
||||||
|
|
||||||
# Run the unittest runner when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user