Generalizing the chat_template prompt strategy (#1660) [skip ci]
The strategy now supports configuring several fields: * The data field holding message arrays * the role and content fields for each message * role mapping from source to target types additionally this adds a sample llama3-8b instruct template using the chat template
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
tests for chat_template prompt strategy
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
@@ -10,8 +11,39 @@ from transformers import AutoTokenizer
|
||||
from axolotl.prompt_strategies.chat_template import (
|
||||
ChatTemplatePrompter,
|
||||
ChatTemplateStrategy,
|
||||
load,
|
||||
)
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="assistant_dataset")
|
||||
def fixture_assistant_dataset():
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "hello",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "goodbye",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "goodbye",
|
||||
},
|
||||
]
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="sharegpt_dataset")
|
||||
@@ -51,6 +83,87 @@ def fixture_llama3_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
class TestAssistantChatTemplateLlama3:
|
||||
"""
|
||||
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
||||
"""
|
||||
|
||||
def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
strategy = load(
|
||||
llama3_tokenizer,
|
||||
DictDefault(
|
||||
{
|
||||
"train_on_inputs": False,
|
||||
"sequence_len": 512,
|
||||
}
|
||||
),
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"roles": {
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
"system": ["system"],
|
||||
},
|
||||
"field_messages": "messages",
|
||||
}
|
||||
),
|
||||
)
|
||||
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
# fmt: off
|
||||
assert input_ids == [
|
||||
128000, # bos
|
||||
128006, 882, 128007, # user header
|
||||
271, 15339, 128009, # user prompt eot
|
||||
128006, 78191, 128007, # assistant header
|
||||
271, 15339, 128009, # assistant response eot
|
||||
128006, 882, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
128006, 78191, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_llama3(self, llama3_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_templates("llama3"),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
roles={
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
"system": ["system"],
|
||||
},
|
||||
),
|
||||
llama3_tokenizer,
|
||||
False,
|
||||
512,
|
||||
)
|
||||
strategy.messages = "messages"
|
||||
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
# fmt: off
|
||||
assert input_ids == [
|
||||
128000, # bos
|
||||
128006, 882, 128007, # user header
|
||||
271, 15339, 128009, # user prompt eot
|
||||
128006, 78191, 128007, # assistant header
|
||||
271, 15339, 128009, # assistant response eot
|
||||
128006, 882, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
128006, 78191, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
class TestSharegptChatTemplateLlama3:
|
||||
"""
|
||||
Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.
|
||||
|
||||
Reference in New Issue
Block a user