feat: handles chat_template requiring specific user/assistant order
This commit is contained in:
@@ -53,6 +53,7 @@ def default(
|
||||
"role": role_map[sample[field_rejected][field_message_role]],
|
||||
"content": sample[field_rejected][field_message_content],
|
||||
}
|
||||
dummy_user_message = {"role": "user", "content": "dummy"}
|
||||
|
||||
result = {}
|
||||
result["prompt"] = tokenizer.apply_chat_template(
|
||||
@@ -63,7 +64,7 @@ def default(
|
||||
)
|
||||
|
||||
result["chosen"] = tokenizer.apply_chat_template(
|
||||
[chosen],
|
||||
[dummy_user_message, chosen],
|
||||
add_generation_prompt=False,
|
||||
chat_template=chat_template_string,
|
||||
tokenize=False,
|
||||
@@ -72,7 +73,7 @@ def default(
|
||||
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
|
||||
|
||||
result["rejected"] = tokenizer.apply_chat_template(
|
||||
[rejected],
|
||||
[dummy_user_message, rejected],
|
||||
add_generation_prompt=False,
|
||||
chat_template=chat_template_string,
|
||||
tokenize=False,
|
||||
|
||||
@@ -93,6 +93,13 @@ def fixture_phi3_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="gemma_tokenizer")
|
||||
def fixture_gemma_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
class TestAssistantDPOChatTemplateLlama3:
|
||||
"""
|
||||
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
||||
@@ -106,7 +113,7 @@ class TestAssistantDPOChatTemplateLlama3:
|
||||
"chat_template": "llama3",
|
||||
"datasets": [
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
"type": "chat_template",
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -131,7 +138,7 @@ class TestAssistantDPOChatTemplateLlama3:
|
||||
"chat_template": "llama3",
|
||||
"datasets": [
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
"type": "chat_template",
|
||||
"field_messages": "conversation",
|
||||
"field_chosen": "better",
|
||||
"field_rejected": "worse",
|
||||
@@ -173,7 +180,6 @@ class TestAssistantDPOChatTemplatePhi3:
|
||||
"datasets": [
|
||||
{
|
||||
"type": "chat_template",
|
||||
"chat_template": "tokenizer_default",
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -190,5 +196,35 @@ class TestAssistantDPOChatTemplatePhi3:
|
||||
assert result["rejected"] == "party on<|end|>"
|
||||
|
||||
|
||||
class TestAssistantDPOChatTemplateGemma:
|
||||
"""
|
||||
Test class for assistant style datasets with gemma prompts using the tokenizer's chat_template strategy.
|
||||
"""
|
||||
|
||||
def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn = default(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "tokenizer_default",
|
||||
"datasets": [
|
||||
{
|
||||
"type": "chat_template",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
result = transform_fn(assistant_dataset[0], tokenizer=gemma_tokenizer)
|
||||
assert result["prompt"] == (
|
||||
"<bos><start_of_turn>user\nhello<end_of_turn>\n"
|
||||
+ "<start_of_turn>model\nhello<end_of_turn>\n"
|
||||
+ "<start_of_turn>user\ngoodbye<end_of_turn>\n"
|
||||
+ "<start_of_turn>model\n"
|
||||
)
|
||||
assert result["chosen"] == "goodbye<end_of_turn>"
|
||||
assert result["rejected"] == "party on<end_of_turn>"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user