Allow message objects to carry an additional key, `weight`, which can be set to 0 (or 1) to mask that message out of training (or leave it unmasked), similar to [1]. This helps train the model to be robust and to recover from errors after a bad assistant message. A missing `weight` key defaults to 1, which guarantees backward compatibility. [1]: https://github.com/mistralai/mistral-finetune
499 lines
17 KiB
Python
499 lines
17 KiB
Python
"""
|
|
Test module for ShareGPT integration with the ChatML and Llama-3 prompt templates
|
|
"""
|
|
|
|
import pytest
|
|
from datasets import Dataset
|
|
from tokenizers import AddedToken
|
|
from transformers import AutoTokenizer
|
|
|
|
from axolotl.datasets import TokenizedPromptDataset
|
|
from axolotl.prompt_strategies.sharegpt import (
|
|
GlaiveShareGPTPromptTokenizingStrategy,
|
|
SimpleShareGPTPromptTokenizingStrategy,
|
|
register_chatml_template,
|
|
register_llama3_template,
|
|
)
|
|
from axolotl.prompters import ShareGPTPrompterV2
|
|
|
|
# Register the conversation templates that the tests in this module select via
# ShareGPTPrompterV2(conversation="chatml") and conversation="llama3".
for _register_template in (register_chatml_template, register_llama3_template):
    _register_template()
|
|
|
|
|
|
@pytest.fixture(name="sharegpt_dataset")
def fixture_sharegpt_dataset():
    """One ShareGPT conversation: a system prompt followed by two echoed turns.

    No message carries a ``weight`` key.
    """
    turns = [
        ("system", "repeat"),
        ("human", "hello"),
        ("gpt", "hello"),
        ("human", "goodbye"),
        ("gpt", "goodbye"),
    ]
    conversation = [{"from": speaker, "value": text} for speaker, text in turns]
    return Dataset.from_list([{"conversations": conversation}])
|
|
|
|
|
|
@pytest.fixture(name="sharegpt_dataset_with_weights")
def fixture_sharegpt_dataset_with_weights():
    """One ShareGPT conversation whose messages carry per-message ``weight`` keys.

    Some messages omit ``weight`` entirely (the system prompt and the final
    human turn) so the default-weight path is exercised as well.
    """
    # (speaker, text, weight); a weight of None means "omit the key".
    turns = [
        ("system", "repeat", None),
        ("human", "hello", 1),
        ("gpt", "hello", 0),
        ("human", "rehello", 0),
        ("gpt", "rehello", 1),
        ("human", "goodbye", None),
        ("gpt", "goodbye", 0),
    ]
    conversation = []
    for speaker, text, weight in turns:
        message = {"from": speaker, "value": text}
        if weight is not None:
            message["weight"] = weight
        conversation.append(message)
    return Dataset.from_list([{"conversations": conversation}])
|
|
|
|
|
|
@pytest.fixture(name="glaive_dataset")
def fixture_sharegpt_glaive_dataset():
    """One Glaive-style record: a ``system`` string plus a flat ``chat`` transcript."""
    record = {
        "system": "SYSTEM: This is a system prompt",
        "chat": (
            "USER: Can you book a flight for me from New York to London? "
            "ASSISTANT: I'm sorry, but I don't have the capability to book flights. "
            "<|endoftext|>"
        ),
    }
    return Dataset.from_list([record])
|
|
|
|
|
|
@pytest.fixture(name="multi_role_dataset")
def fixture_multi_role_dataset():
    """One ShareGPT conversation containing a ``tool`` turn between two gpt turns."""
    turns = (
        ("system", "use get_weather(city) to get the weather for a city"),
        ("human", "hello, what's the weather in New York?"),
        ("gpt", "let me get that for you"),
        ("tool", "get_weather(New York)"),
        ("gpt", "the weather in New York is 70 degrees and sunny"),
    )
    conversation = [{"from": role, "value": content} for role, content in turns]
    return Dataset.from_list([{"conversations": conversation}])
|
|
|
|
|
|
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
    """Mistral tokenizer extended with the ChatML ``<|im_end|>`` / ``<|im_start|>`` tokens.

    ``<|im_end|>`` is registered as the EOS special token first, then
    ``<|im_start|>`` is added as a regular token. The ChatML tests in this
    module expect 32000 closing each turn and 32001 opening it, which
    presumably depends on this insertion order -- keep it.
    """
    verbatim = {"rstrip": False, "lstrip": False, "normalized": False}
    tok = AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
    tok.add_special_tokens({"eos_token": AddedToken("<|im_end|>", **verbatim)})
    tok.add_tokens([AddedToken("<|im_start|>", **verbatim)])
    return tok
|
|
|
|
|
|
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
    """Llama-3 tokenizer whose EOS token is overridden to ``<|eot_id|>``."""
    llama3_tok = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
    # The expected llama-3 ids in this module close every turn with
    # 128009 (eot), so use it as the EOS token.
    llama3_tok.eos_token = "<|eot_id|>"
    return llama3_tok
|
|
|
|
|
|
class TestSharegptLlama3:
    """Test class for ShareGPT style datasets with llama-3 prompts"""

    @staticmethod
    def _tokenize_first(dataset, tokenizer):
        """Tokenize ``dataset`` with the llama-3 ShareGPT strategy and return row 0.

        Every test in this class uses the same setup: ``conversation="llama3"``,
        no custom role keys, ``train_on_inputs=False`` and ``sequence_len=2048``.

        Returns:
            The first tokenized example (read via its ``input_ids`` and
            ``labels`` keys by the tests).
        """
        strategy = SimpleShareGPTPromptTokenizingStrategy(
            ShareGPTPrompterV2(
                conversation="llama3",
                role_key_model=None,
                role_key_human=None,
            ),
            tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )
        dataset_wrapper = TokenizedPromptDataset(strategy, dataset, process_count=1)
        return dataset_wrapper[0]

    def test_tokenization(self, sharegpt_dataset, llama3_tokenizer):
        """An unweighted conversation tokenizes to the expected llama-3 ids."""
        input_ids = self._tokenize_first(sharegpt_dataset, llama3_tokenizer)[
            "input_ids"
        ]

        # fmt: off
        assert input_ids == [
            128000,  # bos
            128006, 9125, 128007,  # system header
            271, 31724, 128009,  # sys prompt, eot
            128006, 882, 128007,  # user header
            271, 15339, 128009,  # user prompt eot
            128006, 78191, 128007,  # assistant header
            271, 15339, 128009,  # assistant response eot
            128006, 882, 128007,
            271, 19045, 29474, 128009,
            128006, 78191, 128007,
            271, 19045, 29474, 128009,
        ]
        # fmt: on

    def test_tokenization_with_weights(
        self, sharegpt_dataset_with_weights, llama3_tokenizer
    ):
        """``weight`` keys leave input_ids intact and mask the weight-0 turns.

        Previously only input_ids were checked, so the weight feature itself
        was never verified for the llama-3 template. The label checks below
        are segmentation-agnostic: they only rely on which contiguous index
        ranges must be fully masked or partially trained.
        """
        sample = self._tokenize_first(sharegpt_dataset_with_weights, llama3_tokenizer)
        input_ids = sample["input_ids"]

        # fmt: off
        assert input_ids == [
            128000,  # bos
            128006, 9125, 128007,  # system header
            271, 31724, 128009,  # sys prompt, eot
            128006, 882, 128007,  # user header
            271, 15339, 128009,  # user prompt eot
            128006, 78191, 128007,  # assistant header
            271, 15339, 128009,  # assistant response eot
            128006, 882, 128007,
            271, 11310, 4896, 128009,
            128006, 78191, 128007,
            271, 11310, 4896, 128009,
            128006, 882, 128007,
            271, 19045, 29474, 128009,
            128006, 78191, 128007,
            271, 19045, 29474, 128009,
        ]
        # fmt: on

        labels = sample["labels"]
        assert len(labels) == len(input_ids)
        # Indices 0-25: bos, system, "hello" human, "hello" gpt (weight 0) and
        # "rehello" human (weight 0). With train_on_inputs=False all are masked.
        assert all(label == -100 for label in labels[:26])
        # Indices 26-32: the "rehello" gpt reply (weight 1) must be trained on.
        assert any(label != -100 for label in labels[26:33])
        # Indices 33-46: the "goodbye" human turn (an input) and the "goodbye"
        # gpt reply (weight 0) are both masked.
        assert all(label == -100 for label in labels[33:])
|
|
|
|
|
|
class TestSharegptChatML:
    """
    Tests for the ShareGPT prompter with the ChatML template.

    Covers input masking (``train_on_inputs``), per-message ``weight``
    masking, the Glaive strategy and custom input roles.
    """

    @staticmethod
    def _tokenize_first(
        dataset,
        tokenizer,
        train_on_inputs,
        strategy_cls=SimpleShareGPTPromptTokenizingStrategy,
        prompter=None,
    ):
        """Tokenize ``dataset`` with a ChatML ShareGPT strategy and return row 0.

        Args:
            dataset: fixture dataset to tokenize.
            tokenizer: tokenizer handed to the strategy.
            train_on_inputs: passed through to the strategy.
            strategy_cls: tokenizing strategy class; defaults to
                ``SimpleShareGPTPromptTokenizingStrategy``.
            prompter: prompter to use; defaults to a ChatML
                ``ShareGPTPrompterV2`` with ``role_key_model`` and
                ``role_key_human`` set to ``None``.

        Returns:
            The first tokenized example (read via its ``input_ids`` and
            ``labels`` keys by the tests).
        """
        if prompter is None:
            prompter = ShareGPTPrompterV2(
                conversation="chatml",
                role_key_model=None,
                role_key_human=None,
            )
        strategy = strategy_cls(
            prompter,
            tokenizer,
            train_on_inputs,
            2048,  # sequence_len
        )
        dataset_wrapper = TokenizedPromptDataset(strategy, dataset, process_count=1)
        return dataset_wrapper[0]

    def test_no_double_im_end(self, sharegpt_dataset, tokenizer):
        """Each turn is closed by a single 32000 id, never a doubled one."""
        input_ids = self._tokenize_first(
            sharegpt_dataset, tokenizer, train_on_inputs=False
        )["input_ids"]
        # fmt: off
        assert input_ids == [
            # 28705, 13, is " \n"
            1,  # bos
            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
            32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt
            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
            32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
        ]
        # fmt: on

    def test_no_double_im_end_with_weights(
        self, sharegpt_dataset_with_weights, tokenizer
    ):
        """``weight`` keys do not change the rendered input_ids."""
        input_ids = self._tokenize_first(
            sharegpt_dataset_with_weights, tokenizer, train_on_inputs=False
        )["input_ids"]
        # fmt: off
        assert input_ids == [
            # 28705, 13, is " \n"
            1,  # bos
            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
            32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt
            32001, 2188, 13, 267, 21558, 32000, 28705, 13,  # human
            32001, 13892, 13, 267, 21558, 32000, 28705, 13,  # gpt
            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
            32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
        ]
        # fmt: on

    def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
        """With train_on_inputs=False only the gpt replies (minus role header) are labelled."""
        labels = self._tokenize_first(
            sharegpt_dataset, tokenizer, train_on_inputs=False
        )["labels"]
        # fmt: off
        assert labels == [
            -100,  # bos
            -100, -100, -100, -100, -100, -100, -100,  # system
            -100, -100, -100, -100, -100, -100, -100,  # human
            -100, -100, 13, 21558, 32000, 28705, 13,  # gpt
            -100, -100, -100, -100, -100, -100, -100, -100,  # human
            -100, -100, 13, 12684, 17664, 32000, 28705, 13,  # gpt
        ]
        # fmt: on

    def test_no_train_on_input_with_weights(
        self, sharegpt_dataset_with_weights, tokenizer
    ):
        """gpt replies with weight 0 are fully masked even though they are outputs."""
        labels = self._tokenize_first(
            sharegpt_dataset_with_weights, tokenizer, train_on_inputs=False
        )["labels"]
        # fmt: off
        assert labels == [
            -100,  # bos
            -100, -100, -100, -100, -100, -100, -100,  # system
            -100, -100, -100, -100, -100, -100, -100,  # human
            -100, -100, -100, -100, -100, -100, -100,  # gpt with weight zero
            -100, -100, -100, -100, -100, -100, -100, -100,  # human
            -100, -100, 13, 267, 21558, 32000, 28705, 13,  # gpt
            -100, -100, -100, -100, -100, -100, -100, -100,  # human
            -100, -100, -100, -100, -100, -100, -100, -100,  # gpt with weight zero
        ]
        # fmt: on

    def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
        """With train_on_inputs=True every token is labelled."""
        labels = self._tokenize_first(
            sharegpt_dataset, tokenizer, train_on_inputs=True
        )["labels"]
        # fmt: off
        assert labels == [
            1,  # bos
            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
            32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt
            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
            32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
        ]
        # fmt: on

    def test_w_train_on_input_with_weights(
        self, sharegpt_dataset_with_weights, tokenizer
    ):
        """With train_on_inputs=True, weight-0 messages of any role are still masked."""
        labels = self._tokenize_first(
            sharegpt_dataset_with_weights, tokenizer, train_on_inputs=True
        )["labels"]
        # fmt: off
        assert labels == [
            1,  # bos
            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
            -100, -100, -100, -100, -100, -100, -100,  # gpt with weight 0
            -100, -100, -100, -100, -100, -100, -100, -100,  # human with weight 0
            32001, 13892, 13, 267, 21558, 32000, 28705, 13,  # gpt
            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
            -100, -100, -100, -100, -100, -100, -100, -100,  # gpt with weight 0
        ]
        # fmt: on

    def test_chatml_glaive(self, glaive_dataset, tokenizer):
        """The Glaive strategy splits the flat ``chat`` string into ChatML turns."""
        labels = self._tokenize_first(
            glaive_dataset,
            tokenizer,
            train_on_inputs=True,
            strategy_cls=GlaiveShareGPTPromptTokenizingStrategy,
        )["labels"]
        # fmt: off
        assert labels == [
            1,  # bos
            32001, 1587, 13, 3260, 349, 264, 1587, 11510, 32000, 28705, 13,  # system
            32001, 2188, 13, 6325, 368, 1820, 264, 9314, 354, 528, 477, 1450, 2726, 298, 4222, 28804, 32000, 28705, 13,  # human
            32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13,  # gpt
        ]
        # fmt: on

    def test_multi_role_dataset(self, multi_role_dataset, tokenizer):
        """A ``tool`` role configured as an input is rendered but masked."""
        sample = self._tokenize_first(
            multi_role_dataset,
            tokenizer,
            train_on_inputs=False,
            prompter=ShareGPTPrompterV2(
                conversation="chatml", roles={"input": ["tool"]}
            ),
        )

        input_ids = sample["input_ids"]
        # fmt: off
        assert input_ids == [
            1,  # bos
            32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13,  # system
            32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13,  # human
            32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13,  # gpt
            32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13,  # tool
            32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13,  # gpt
        ]
        # fmt: on

        labels = sample["labels"]
        # fmt: off
        assert labels == [
            -100,  # bos
            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # system
            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # human
            -100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13,  # gpt
            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # tool
            -100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13,  # gpt
        ]
        # fmt: on
|