Fix prompt assembly for llama (#952)

* start at index 0

* add test to check for missing turns

* apply black

* Update test_prompt_tokenizers.py

* Update src/axolotl/monkeypatch/fastchat_conversation_turns.py

Co-authored-by: Motoki Wu <tokestermw@gmail.com>

* fix linting

* apply black

* add more tests for llama/sharegpt

* make logic clearer

---------

Co-authored-by: Motoki Wu <tokestermw@gmail.com>
This commit is contained in:
Hamel Husain
2023-12-14 10:03:59 -08:00
committed by GitHub
parent 712fd27b3f
commit 5ada140ff0
2 changed files with 82 additions and 5 deletions

View File

@@ -83,14 +83,21 @@ def get_turns( # pylint: disable=too-many-return-statements
yield role + ":", "" yield role + ":", ""
return return
if self.sep_style == SeparatorStyle.LLAMA2: if self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
if self.system_message: if self.system_message:
if self.messages:
# For llama, the system message is incorporated into the first human instruction
first_role, first_msg = self.messages[0]
if first_role == self.roles[0]:
system_prompt += first_msg
self.messages.pop(0)
yield "", system_prompt yield "", system_prompt
else: for i, (role, message) in enumerate(self.messages):
yield "", "[INST] "
for i, (role, message) in enumerate(self.messages[1:]):
if message: if message:
yield role + " ", message + seps[i % 2] if (i % 2 == 0 and not self.system_message) or (
i % 2 != 0 and self.system_message
):
role = "<s> " + role
yield role + " ", message
else: else:
yield role, "" yield role, ""
return return

View File

@@ -114,6 +114,76 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
in self._caplog.records[0].message in self._caplog.records[0].message
) )
def test_sharegpt_llama(self):
"Make sure the sharegpt/llama is tokenized and formatted correctly."
prompter = ShareGPTPrompterV2(conversation="llama-2")
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
def tokenize(conv):
return strat.tokenize_prompt(conv)["input_ids"]
def decode(ids):
return strat.tokenizer.decode(ids)
# Multi-turn conversations
multi_turn_conv = {
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "human", "value": "abc"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "123"},
{"from": "gpt", "value": "sit"},
]
}
# fmt: off
mt_ids = tokenize(multi_turn_conv)
assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
# Single-turn conversations
single_turn_conv = {
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "human", "value": "abc"},
{"from": "gpt", "value": "ipsum"},
]
}
st_ids = tokenize(single_turn_conv)
assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
# No system message, single-turn
no_sys_conv = {
"conversations": [
{"from": "human", "value": "abc"},
{"from": "gpt", "value": "ipsum"},
]
}
ns_ids = tokenize(no_sys_conv)
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
# No system message, multi-turn
no_sys_mt_conv = {
"conversations": [
{"from": "human", "value": "abc"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "123"},
{"from": "gpt", "value": "sit"},
]
}
ns_mt_ids = tokenize(no_sys_mt_conv)
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
# fmt: on
def test_sharegpt_changes_roles(self): def test_sharegpt_changes_roles(self):
conversation = { conversation = {
"roles": ["USER", "CHARACTER"], "roles": ["USER", "CHARACTER"],