Fix prompt assembly for llama (#952)
* start at index 0
* add test to check for missing turns
* apply black
* Update test_prompt_tokenizers.py
* Update src/axolotl/monkeypatch/fastchat_conversation_turns.py

Co-authored-by: Motoki Wu <tokestermw@gmail.com>

* fix linting
* apply black
* add more tests for llama/sharegpt
* make logic clearer

---------

Co-authored-by: Motoki Wu <tokestermw@gmail.com>
This commit is contained in:
@@ -83,14 +83,21 @@ def get_turns( # pylint: disable=too-many-return-statements
|
|||||||
yield role + ":", ""
|
yield role + ":", ""
|
||||||
return
|
return
|
||||||
if self.sep_style == SeparatorStyle.LLAMA2:
|
if self.sep_style == SeparatorStyle.LLAMA2:
|
||||||
seps = [self.sep, self.sep2]
|
|
||||||
if self.system_message:
|
if self.system_message:
|
||||||
|
if self.messages:
|
||||||
|
# For llama, the system message is incorporated into the first human instruction
|
||||||
|
first_role, first_msg = self.messages[0]
|
||||||
|
if first_role == self.roles[0]:
|
||||||
|
system_prompt += first_msg
|
||||||
|
self.messages.pop(0)
|
||||||
yield "", system_prompt
|
yield "", system_prompt
|
||||||
else:
|
for i, (role, message) in enumerate(self.messages):
|
||||||
yield "", "[INST] "
|
|
||||||
for i, (role, message) in enumerate(self.messages[1:]):
|
|
||||||
if message:
|
if message:
|
||||||
yield role + " ", message + seps[i % 2]
|
if (i % 2 == 0 and not self.system_message) or (
|
||||||
|
i % 2 != 0 and self.system_message
|
||||||
|
):
|
||||||
|
role = "<s> " + role
|
||||||
|
yield role + " ", message
|
||||||
else:
|
else:
|
||||||
yield role, ""
|
yield role, ""
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -114,6 +114,76 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
in self._caplog.records[0].message
|
in self._caplog.records[0].message
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_sharegpt_llama(self):
|
||||||
|
"Make sure the sharegpt/llama is tokenized and formatted correctly."
|
||||||
|
prompter = ShareGPTPrompterV2(conversation="llama-2")
|
||||||
|
strat = ShareGPTPromptTokenizingStrategy(
|
||||||
|
prompter,
|
||||||
|
self.tokenizer,
|
||||||
|
False,
|
||||||
|
2048,
|
||||||
|
)
|
||||||
|
|
||||||
|
def tokenize(conv):
|
||||||
|
return strat.tokenize_prompt(conv)["input_ids"]
|
||||||
|
|
||||||
|
def decode(ids):
|
||||||
|
return strat.tokenizer.decode(ids)
|
||||||
|
|
||||||
|
# Multi-turn conversations
|
||||||
|
multi_turn_conv = {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "system", "value": "lorem"},
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
{"from": "human", "value": "123"},
|
||||||
|
{"from": "gpt", "value": "sit"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
# fmt: off
|
||||||
|
mt_ids = tokenize(multi_turn_conv)
|
||||||
|
assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
||||||
|
assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||||
|
|
||||||
|
# Single-turn conversations
|
||||||
|
single_turn_conv = {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "system", "value": "lorem"},
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
st_ids = tokenize(single_turn_conv)
|
||||||
|
assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
|
||||||
|
assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
||||||
|
|
||||||
|
# No system message, single-turn
|
||||||
|
no_sys_conv = {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
ns_ids = tokenize(no_sys_conv)
|
||||||
|
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
|
||||||
|
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
|
||||||
|
|
||||||
|
# No system message, multi-turn
|
||||||
|
no_sys_mt_conv = {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
{"from": "human", "value": "123"},
|
||||||
|
{"from": "gpt", "value": "sit"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
ns_mt_ids = tokenize(no_sys_mt_conv)
|
||||||
|
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
||||||
|
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
def test_sharegpt_changes_roles(self):
|
def test_sharegpt_changes_roles(self):
|
||||||
conversation = {
|
conversation = {
|
||||||
"roles": ["USER", "CHARACTER"],
|
"roles": ["USER", "CHARACTER"],
|
||||||
|
|||||||
Reference in New Issue
Block a user