Generalizing the chat_template prompt strategy (#1660) [skip ci]

The strategy now supports configuring several fields: * The data field holding message arrays * the role and
content fields for each message * role mapping from source to target types

additionally this adds a sample llama3-8b instruct template using the chat template
This commit is contained in:
Keith Stevens
2024-05-29 00:24:13 +09:00
committed by GitHub
parent 5f91064040
commit cc11c6bce2
4 changed files with 258 additions and 17 deletions

View File

@@ -1,6 +1,7 @@
"""
tests for chat_template prompt strategy
"""
import unittest
import pytest
@@ -10,8 +11,39 @@ from transformers import AutoTokenizer
from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
ChatTemplateStrategy,
load,
)
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
"messages": [
{
"role": "user",
"content": "hello",
},
{
"role": "assistant",
"content": "hello",
},
{
"role": "user",
"content": "goodbye",
},
{
"role": "assistant",
"content": "goodbye",
},
]
}
]
)
@pytest.fixture(name="sharegpt_dataset")
@@ -51,6 +83,87 @@ def fixture_llama3_tokenizer():
return tokenizer
class TestAssistantChatTemplateLlama3:
"""
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
"""
def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
strategy = load(
llama3_tokenizer,
DictDefault(
{
"train_on_inputs": False,
"sequence_len": 512,
}
),
DictDefault(
{
"chat_template": "llama3",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"user": ["user"],
"assistant": ["assistant"],
"system": ["system"],
},
"field_messages": "messages",
}
),
)
res = strategy.tokenize_prompt(assistant_dataset[0])
input_ids = res["input_ids"]
# fmt: off
assert input_ids == [
128000, # bos
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot
128006, 882, 128007,
271, 19045, 29474, 128009,
128006, 78191, 128007,
271, 19045, 29474, 128009,
]
# fmt: on
def test_llama3(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_templates("llama3"),
message_field_role="role",
message_field_content="content",
roles={
"user": ["user"],
"assistant": ["assistant"],
"system": ["system"],
},
),
llama3_tokenizer,
False,
512,
)
strategy.messages = "messages"
res = strategy.tokenize_prompt(assistant_dataset[0])
input_ids = res["input_ids"]
# fmt: off
assert input_ids == [
128000, # bos
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot
128006, 882, 128007,
271, 19045, 29474, 128009,
128006, 78191, 128007,
271, 19045, 29474, 128009,
]
# fmt: on
class TestSharegptChatTemplateLlama3:
"""
Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.