feat: handles chat_template requiring specific user/assistant order

This commit is contained in:
NanoCode012
2024-10-14 14:00:55 +07:00
parent e5cd55cff9
commit d101cfc125
2 changed files with 42 additions and 5 deletions

View File

@@ -53,6 +53,7 @@ def default(
"role": role_map[sample[field_rejected][field_message_role]],
"content": sample[field_rejected][field_message_content],
}
dummy_user_message = {"role": "user", "content": "dummy"}
result = {}
result["prompt"] = tokenizer.apply_chat_template(
@@ -63,7 +64,7 @@ def default(
)
result["chosen"] = tokenizer.apply_chat_template(
[chosen],
[dummy_user_message, chosen],
add_generation_prompt=False,
chat_template=chat_template_string,
tokenize=False,
@@ -72,7 +73,7 @@ def default(
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template(
[rejected],
[dummy_user_message, rejected],
add_generation_prompt=False,
chat_template=chat_template_string,
tokenize=False,

View File

@@ -93,6 +93,13 @@ def fixture_phi3_tokenizer():
return tokenizer
@pytest.fixture(name="gemma_tokenizer")
def fixture_gemma_tokenizer():
    """Provide a Gemma-2b-it tokenizer pinned to a fixed revision for reproducible tests."""
    return AutoTokenizer.from_pretrained(
        "unsloth/gemma-2b-it", revision="703fb4a"
    )
class TestAssistantDPOChatTemplateLlama3:
"""
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
@@ -106,7 +113,7 @@ class TestAssistantDPOChatTemplateLlama3:
"chat_template": "llama3",
"datasets": [
{
"chat_template": "llama3",
"type": "chat_template",
}
],
}
@@ -131,7 +138,7 @@ class TestAssistantDPOChatTemplateLlama3:
"chat_template": "llama3",
"datasets": [
{
"chat_template": "llama3",
"type": "chat_template",
"field_messages": "conversation",
"field_chosen": "better",
"field_rejected": "worse",
@@ -173,7 +180,6 @@ class TestAssistantDPOChatTemplatePhi3:
"datasets": [
{
"type": "chat_template",
"chat_template": "tokenizer_default",
}
],
}
@@ -190,5 +196,35 @@ class TestAssistantDPOChatTemplatePhi3:
assert result["rejected"] == "party on<|end|>"
class TestAssistantDPOChatTemplateGemma:
    """
    Test class for assistant style datasets with gemma prompts using the tokenizer's chat_template strategy.
    """

    def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
        # pylint: disable=duplicate-code
        # Build the DPO transform from a minimal config that defers to the
        # tokenizer's own chat template.
        cfg = DictDefault(
            {
                "chat_template": "tokenizer_default",
                "datasets": [
                    {
                        "type": "chat_template",
                    }
                ],
            }
        )
        transform_fn = default(cfg)

        result = transform_fn(assistant_dataset[0], tokenizer=gemma_tokenizer)

        expected_prompt = (
            "<bos><start_of_turn>user\nhello<end_of_turn>\n"
            "<start_of_turn>model\nhello<end_of_turn>\n"
            "<start_of_turn>user\ngoodbye<end_of_turn>\n"
            "<start_of_turn>model\n"
        )
        assert result["prompt"] == expected_prompt
        assert result["chosen"] == "goodbye<end_of_turn>"
        assert result["rejected"] == "party on<end_of_turn>"
if __name__ == "__main__":
    # NOTE(review): these tests are plain pytest-style classes (no
    # unittest.TestCase), so unittest.main() likely discovers nothing when
    # run directly — confirm whether pytest.main([__file__]) was intended.
    unittest.main()