fix mistral prompt assembly (#982)
* fix mistral prompts * fix spacing * remove elif
This commit is contained in:
@@ -82,7 +82,7 @@ def get_turns( # pylint: disable=too-many-return-statements
|
|||||||
else:
|
else:
|
||||||
yield role + ":", ""
|
yield role + ":", ""
|
||||||
return
|
return
|
||||||
if self.sep_style == SeparatorStyle.LLAMA2:
|
if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
|
||||||
if self.system_message:
|
if self.system_message:
|
||||||
if self.messages:
|
if self.messages:
|
||||||
# For llama, the system message is incorporated into the first human instruction
|
# For llama, the system message is incorporated into the first human instruction
|
||||||
@@ -101,6 +101,28 @@ def get_turns( # pylint: disable=too-many-return-statements
|
|||||||
else:
|
else:
|
||||||
yield role, ""
|
yield role, ""
|
||||||
return
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
|
||||||
|
contains_sys_msg = False
|
||||||
|
if self.system_message:
|
||||||
|
contains_sys_msg = True
|
||||||
|
if self.messages:
|
||||||
|
# There is no clear guidance on how to handle system messages in Mistral, so we just prepend it to the first human instruction, separated by a newline
|
||||||
|
first_role, first_msg = self.messages[0]
|
||||||
|
if first_role == self.roles[0]:
|
||||||
|
system_prompt = self.system_template.format(
|
||||||
|
system_message=" " + self.system_message
|
||||||
|
)
|
||||||
|
system_prompt += first_msg
|
||||||
|
self.messages.pop(0)
|
||||||
|
yield "", system_prompt
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message and i == 0 and not contains_sys_msg:
|
||||||
|
yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is a `<s> [INST]` at the beginning of the first instruction.
|
||||||
|
elif message:
|
||||||
|
yield role + " ", message
|
||||||
|
else:
|
||||||
|
yield role, ""
|
||||||
|
return
|
||||||
if self.sep_style == SeparatorStyle.CHATGLM:
|
if self.sep_style == SeparatorStyle.CHATGLM:
|
||||||
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
||||||
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import unittest
|
import unittest
|
||||||
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -25,6 +26,50 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
|
|||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
test_data = {
|
||||||
|
"multi_turn_sys": {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "system", "value": "lorem"},
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
{"from": "human", "value": "123"},
|
||||||
|
{"from": "gpt", "value": "sit"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"single_turn_sys": {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "system", "value": "lorem"},
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"single_turn_no_sys": {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"multi_turn_no_sys": {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
{"from": "human", "value": "123"},
|
||||||
|
{"from": "gpt", "value": "sit"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_strat(conversation, tokenizer):
|
||||||
|
"Helper function to create a prompt strategy for testing."
|
||||||
|
prompter = ShareGPTPrompterV2(conversation=conversation)
|
||||||
|
return ShareGPTPromptTokenizingStrategy(
|
||||||
|
prompter,
|
||||||
|
tokenizer,
|
||||||
|
False,
|
||||||
|
2048,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestPromptTokenizationStrategies(unittest.TestCase):
|
class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
@@ -116,74 +161,68 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
|
|
||||||
def test_sharegpt_llama(self):
|
def test_sharegpt_llama(self):
|
||||||
"Make sure the sharegpt/llama is tokenized and formatted correctly."
|
"Make sure the sharegpt/llama is tokenized and formatted correctly."
|
||||||
prompter = ShareGPTPrompterV2(conversation="llama-2")
|
strat = prompt_strat("llama-2", self.tokenizer)
|
||||||
strat = ShareGPTPromptTokenizingStrategy(
|
|
||||||
prompter,
|
|
||||||
self.tokenizer,
|
|
||||||
False,
|
|
||||||
2048,
|
|
||||||
)
|
|
||||||
|
|
||||||
def tokenize(conv):
|
def tokenize(conv):
|
||||||
return strat.tokenize_prompt(conv)["input_ids"]
|
return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
|
||||||
|
|
||||||
def decode(ids):
|
def decode(ids):
|
||||||
return strat.tokenizer.decode(ids)
|
return strat.tokenizer.decode(ids)
|
||||||
|
|
||||||
# Multi-turn conversations
|
|
||||||
multi_turn_conv = {
|
|
||||||
"conversations": [
|
|
||||||
{"from": "system", "value": "lorem"},
|
|
||||||
{"from": "human", "value": "abc"},
|
|
||||||
{"from": "gpt", "value": "ipsum"},
|
|
||||||
{"from": "human", "value": "123"},
|
|
||||||
{"from": "gpt", "value": "sit"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
mt_ids = tokenize(multi_turn_conv)
|
# System message, multi-turn conversations
|
||||||
|
mt_ids = tokenize(test_data['multi_turn_sys'])
|
||||||
assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
||||||
assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||||
|
|
||||||
# Single-turn conversations
|
# System message, single-turn conversations
|
||||||
single_turn_conv = {
|
st_ids = tokenize(test_data['single_turn_sys'])
|
||||||
"conversations": [
|
|
||||||
{"from": "system", "value": "lorem"},
|
|
||||||
{"from": "human", "value": "abc"},
|
|
||||||
{"from": "gpt", "value": "ipsum"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
st_ids = tokenize(single_turn_conv)
|
|
||||||
assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
|
assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
|
||||||
assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
||||||
|
|
||||||
# No system message, single-turn
|
# No system message, single-turn
|
||||||
no_sys_conv = {
|
ns_ids = tokenize(test_data['single_turn_no_sys'])
|
||||||
"conversations": [
|
|
||||||
{"from": "human", "value": "abc"},
|
|
||||||
{"from": "gpt", "value": "ipsum"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
ns_ids = tokenize(no_sys_conv)
|
|
||||||
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
|
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
|
||||||
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
|
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
|
||||||
|
|
||||||
# No system message, multi-turn
|
# No system message, multi-turn
|
||||||
no_sys_mt_conv = {
|
ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
|
||||||
"conversations": [
|
|
||||||
{"from": "human", "value": "abc"},
|
|
||||||
{"from": "gpt", "value": "ipsum"},
|
|
||||||
{"from": "human", "value": "123"},
|
|
||||||
{"from": "gpt", "value": "sit"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
ns_mt_ids = tokenize(no_sys_mt_conv)
|
|
||||||
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
||||||
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
def test_sharegpt_mistral(self):
|
||||||
|
"Make sure the sharegpt/mistral is tokenized and formatted correctly."
|
||||||
|
strat = prompt_strat("mistral", self.tokenizer)
|
||||||
|
|
||||||
|
def tokenize(conv):
|
||||||
|
return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
|
||||||
|
|
||||||
|
def decode(ids):
|
||||||
|
return strat.tokenizer.decode(ids)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
# System message, multi-turn conversations
|
||||||
|
mt_ids = tokenize(test_data['multi_turn_sys'])
|
||||||
|
assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
|
||||||
|
assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||||
|
|
||||||
|
# System message, single-turn conversations
|
||||||
|
st_ids = tokenize(test_data['single_turn_sys'])
|
||||||
|
assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
|
||||||
|
assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
||||||
|
|
||||||
|
# No system message, single-turn
|
||||||
|
ns_ids = tokenize(test_data['single_turn_no_sys'])
|
||||||
|
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
|
||||||
|
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
|
||||||
|
|
||||||
|
# No system message, multi-turn
|
||||||
|
ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
|
||||||
|
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
|
||||||
|
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
def test_sharegpt_changes_roles(self):
|
def test_sharegpt_changes_roles(self):
|
||||||
conversation = {
|
conversation = {
|
||||||
"roles": ["USER", "CHARACTER"],
|
"roles": ["USER", "CHARACTER"],
|
||||||
|
|||||||
Reference in New Issue
Block a user