Files
axolotl/tests/prompt_strategies/test_chat_templates_mistral.py

291 lines
11 KiB
Python

"""Test chat templates for mistral-common wrapper tokenizer"""
import unittest
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
# pylint: disable=duplicate-code
from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy
# check bos, eos, pad, unk are accessible properties
assert magistral_tokenizer.bos_token_id == 1
assert magistral_tokenizer.eos_token_id == 2
assert magistral_tokenizer.pad_token_id == 11
assert magistral_tokenizer.unk_token_id == 0
assert magistral_tokenizer.pad_token == "<pad>"
assert magistral_tokenizer.eos_token == "</s>"
assert magistral_tokenizer.bos_token == "<s>"
assert magistral_tokenizer.unk_token == "<unk>"
strategy = MistralStrategy(
MistralPrompter(
magistral_tokenizer,
chat_template=None,
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=magistral_tokenizer,
train_on_inputs=False,
train_on_eos="turn",
sequence_len=512,
roles_to_train=["assistant"],
)
# test chat template masking without system prompt
res = strategy.tokenize_prompt(
{
"messages": [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great, thank you!"},
]
}
)
assert res["input_ids"] == [
1, # bos
3, # [INST]
22177, # Hello
1044, # ,
2606, # how
1584, # are
1636, # you
1063, # ?
4, # [/INST]
1073, # I
4525, # 'm
6965, # doing
4824, # great
1044, # ,
15412, # thank
1636, # you
1033, # !
2, # </s>
]
assert res["labels"] == [
-100, # bos
-100, # [INST]
-100, # Hello
-100, # ,
-100, # how
-100, # are
-100, # you
-100, # ?
-100, # [/INST]
1073, # I
4525, # 'm
6965, # doing
4824, # great
1044, # ,
15412, # thank
1636, # you
1033, # !
2, # </s>
]
# test chat template masking with system prompt
res = strategy.tokenize_prompt(
{
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great, thank you!"},
]
}
)
assert res["input_ids"] == [
1, # bos
17, # [SYSTEM_PROMPT]
4568, # You
1584, # are
1261, # a
20351, # helpful
27089, # assistant
1046, # .
18, # [/SYSTEM_PROMPT]
3, # [INST]
22177, # Hello
1044, # ,
2606, # how
1584, # are
1636, # you
1063, # ?
4, # [/INST]
1073, # I
4525, # 'm
6965, # doing
4824, # great
1044, # ,
15412, # thank
1636, # you
1033, # !
2, # </s>
]
assert res["labels"] == [
-100, # bos
-100, # [SYSTEM_PROMPT]
-100, # You
-100, # are
-100, # a
-100, # helpful
-100, # assistant
-100, # .
-100, # [/SYSTEM_PROMPT]
-100, # [INST]
-100, # Hello
-100, # ,
-100, # how
-100, # are
-100, # you
-100, # ?
-100, # [/INST]
1073, # I
4525, # 'm
6965, # doing
4824, # great
1044, # ,
15412, # thank
1636, # you
1033, # !
2, # </s>
]
# test chat template with tools
res = strategy.tokenize_prompt(
{
"tools": [
{
"type": "function",
"function": {
"name": "multiples",
"description": "Generates a list of all the multiples of a number that are less than a given limit.",
"parameters": {
"type": "object",
"properties": {
"number": {
"type": "integer",
"description": "The number to find multiples of.",
},
"limit": {
"type": "integer",
"description": "The upper limit for the multiples.",
},
},
"required": ["number", "limit"],
},
},
},
],
"messages": [
{
"role": "user",
"content": "Hey, can you give me a breakdown of how to throw an awesome themed party? Like, what themes work best, and how can I set everything up to really wow my guests? I want some ideas on decorations, food, and activities that will make the party unforgettable!",
},
{
"role": "assistant",
"tool_calls": [
{
"id": "call12345",
"type": "function",
"function": {
"name": "multiples",
"arguments": {
"number": 16,
"limit": 2,
},
},
}
],
},
{
"role": "tool",
"tool_call_id": "call12345",
"name": "multiples",
"content": "1,2",
},
{"role": "assistant", "content": "The multiples of 16 is 1 and 2."},
],
}
)
# fmt: off
assert res["input_ids"] == [
1, # bos
5, 1091, 19227, 4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 14653, 2811, 1429, 10639, 2130, 1261, 2951, 1307, 1747, 1278, 60092, 1307, 1261, 2782, 1455, 1584, 4289, 2224, 1261, 4265, 6139, 39249, 1429, 26204, 2811, 16753, 4994, 2811, 1429, 6371, 1897, 1429, 48649, 2811, 16753, 12856, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 2782, 1317, 3081, 60092, 1307, 2613, 4179, 1429, 33319, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 9229, 6139, 1394, 1278, 60092, 2613, 47579, 1429, 15760, 2811, 12161, 12856, 1897, 1429, 33319, 4964, 2821, 27028, 6, # tool prompt
3, 46634, 1044, 1710, 1636, 5628, 1639, 1261, 44433, 1307, 2606, 1317, 5388, 1420, 54191, 2424, 1286, 8967, 1063, 15621, 1044, 2549, 30305, 2196, 3560, 1044, 1321, 2606, 1710, 1362, 2016, 8605, 2015, 1317, 5524, 118931, 2036, 32951, 1063, 1362, 2933, 2269, 12106, 1408, 101987, 1044, 6939, 1044, 1321, 9216, 1455, 2084, 3180, 1278, 8967, 119141, 1689, 5935, 1033, 4, # user
9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2, # assistant tool calling
7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8, # tool result
1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
2 # eos
]
assert res["labels"] == [
-100, # bos
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool prompt
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # user prompt
9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2, # assistant tool calling
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool result
1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
2 # eos
]
# fmt: on
# test chat template with tokenize=False
res = magistral_tokenizer.apply_chat_template(
[
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great, thank you!"},
],
tokenize=False,
)
assert res == "<s>[INST]Hello, how are you?[/INST]I'm doing great, thank you!</s>"
# test encode
res = magistral_tokenizer.encode("Hello, how are you?", add_special_tokens=True)
assert res == [
1, # bos
22177, # Hello
1044, # ,
2606, # how
1584, # are
1636, # you
1063, # ?
2, # eos
]
# test decode no skip special tokens
decoded_res = magistral_tokenizer.decode(res, skip_special_tokens=False)
assert decoded_res == "<s>Hello, how are you?</s>"
# test decode skip special tokens
decoded_res = magistral_tokenizer.decode(res, skip_special_tokens=True)
assert decoded_res == "Hello, how are you?"
# test encode no special tokens
res = magistral_tokenizer.encode("Hello, how are you?", add_special_tokens=False)
assert res == [
22177, # Hello
1044, # ,
2606, # how
1584, # are
1636, # you
1063, # ?
]
# test convert ids to tokens
res = magistral_tokenizer.convert_ids_to_tokens(res)
# spacing are needed as we are converting without decoding
assert res == ["Hello", ",", " how", " are", " you", "?"]
if __name__ == "__main__":
unittest.main()