Fix prompt assembly for llama (#952)
* start at index 0 * add test to check for missing turns * apply black * Update test_prompt_tokenizers.py * Update src/axolotl/monkeypatch/fastchat_conversation_turns.py Co-authored-by: Motoki Wu <tokestermw@gmail.com> * fix linting * apply black * add more tests for llama/sharegpt * make logic clearer --------- Co-authored-by: Motoki Wu <tokestermw@gmail.com>
This commit is contained in:
@@ -114,6 +114,76 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
in self._caplog.records[0].message
|
||||
)
|
||||
|
||||
def test_sharegpt_llama(self):
|
||||
"Make sure the sharegpt/llama is tokenized and formatted correctly."
|
||||
prompter = ShareGPTPrompterV2(conversation="llama-2")
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
|
||||
def tokenize(conv):
|
||||
return strat.tokenize_prompt(conv)["input_ids"]
|
||||
|
||||
def decode(ids):
|
||||
return strat.tokenizer.decode(ids)
|
||||
|
||||
# Multi-turn conversations
|
||||
multi_turn_conv = {
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "human", "value": "abc"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "123"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
]
|
||||
}
|
||||
# fmt: off
|
||||
mt_ids = tokenize(multi_turn_conv)
|
||||
assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
||||
assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||
|
||||
# Single-turn conversations
|
||||
single_turn_conv = {
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "human", "value": "abc"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
]
|
||||
}
|
||||
|
||||
st_ids = tokenize(single_turn_conv)
|
||||
assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
|
||||
assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
||||
|
||||
# No system message, single-turn
|
||||
no_sys_conv = {
|
||||
"conversations": [
|
||||
{"from": "human", "value": "abc"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
]
|
||||
}
|
||||
|
||||
ns_ids = tokenize(no_sys_conv)
|
||||
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
|
||||
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
|
||||
|
||||
# No system message, multi-turn
|
||||
no_sys_mt_conv = {
|
||||
"conversations": [
|
||||
{"from": "human", "value": "abc"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "123"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
]
|
||||
}
|
||||
ns_mt_ids = tokenize(no_sys_mt_conv)
|
||||
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
||||
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||
# fmt: on
|
||||
|
||||
def test_sharegpt_changes_roles(self):
|
||||
conversation = {
|
||||
"roles": ["USER", "CHARACTER"],
|
||||
|
||||
Reference in New Issue
Block a user