Add a chat_template prompt strategy for DPO (#1725)
* Implementing a basic chat_template strategy for DPO datasets This mimics the sft chat_template strategy such that users can: * Specify the messages field * Specify the per message role and content fields * speicfy the chosen and rejected fields * Let the tokenizer construct the raw prompt * Ensure the chosen and rejected fields don't have any prefix tokens * Adding additional dpo chat template unittests * Rename test class
This commit is contained in:
156
tests/prompt_strategies/test_dpo_chat_templates.py
Normal file
156
tests/prompt_strategies/test_dpo_chat_templates.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
tests for chat_template prompt strategy
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.dpo.chat_template import default
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="assistant_dataset")
|
||||
def fixture_assistant_dataset():
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "hello",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "goodbye",
|
||||
},
|
||||
],
|
||||
"chosen": {
|
||||
"role": "assistant",
|
||||
"content": "goodbye",
|
||||
},
|
||||
"rejected": {
|
||||
"role": "assistant",
|
||||
"content": "party on",
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="custom_assistant_dataset")
|
||||
def fixture_custom_assistant_dataset():
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"conversation": [
|
||||
{
|
||||
"speaker": "human",
|
||||
"text": "hello",
|
||||
},
|
||||
{
|
||||
"speaker": "agent",
|
||||
"text": "hello",
|
||||
},
|
||||
{
|
||||
"speaker": "human",
|
||||
"text": "goodbye",
|
||||
},
|
||||
],
|
||||
"better": {
|
||||
"speaker": "agent",
|
||||
"text": "goodbye",
|
||||
},
|
||||
"worse": {
|
||||
"speaker": "agent",
|
||||
"text": "party on",
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="llama3_tokenizer")
|
||||
def fixture_llama3_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
||||
tokenizer.eos_token = "<|eot_id|>"
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
class TestAssistantDPOChatTemplateLlama3:
|
||||
"""
|
||||
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
||||
"""
|
||||
|
||||
def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn = default(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
"datasets": [
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
result = transform_fn(assistant_dataset[0], tokenizer=llama3_tokenizer)
|
||||
assert result["prompt"] == (
|
||||
"<|begin_of_text|>"
|
||||
+ "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
|
||||
+ "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
|
||||
+ "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
|
||||
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
assert result["chosen"] == "goodbye<|eot_id|>"
|
||||
assert result["rejected"] == "party on<|eot_id|>"
|
||||
|
||||
def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn = default(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
"datasets": [
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
"field_messages": "conversation",
|
||||
"field_chosen": "better",
|
||||
"field_rejected": "worse",
|
||||
"message_field_role": "speaker",
|
||||
"message_field_content": "text",
|
||||
"roles": {
|
||||
"user": ["human"],
|
||||
"assistant": ["agent"],
|
||||
"system": ["sys"],
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
result = transform_fn(custom_assistant_dataset[0], tokenizer=llama3_tokenizer)
|
||||
assert result["prompt"] == (
|
||||
"<|begin_of_text|>"
|
||||
+ "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
|
||||
+ "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
|
||||
+ "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
|
||||
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
assert result["chosen"] == "goodbye<|eot_id|>"
|
||||
assert result["rejected"] == "party on<|eot_id|>"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user