fix mistral prompt assembly (#982)
* fix mistral prompts * fix spacing * remove elif
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
import json
|
||||
import logging
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -25,6 +26,50 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
test_data = {
|
||||
"multi_turn_sys": {
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "human", "value": "abc"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "123"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
]
|
||||
},
|
||||
"single_turn_sys": {
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "human", "value": "abc"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
]
|
||||
},
|
||||
"single_turn_no_sys": {
|
||||
"conversations": [
|
||||
{"from": "human", "value": "abc"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
]
|
||||
},
|
||||
"multi_turn_no_sys": {
|
||||
"conversations": [
|
||||
{"from": "human", "value": "abc"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "123"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def prompt_strat(conversation, tokenizer):
|
||||
"Helper function to create a prompt strategy for testing."
|
||||
prompter = ShareGPTPrompterV2(conversation=conversation)
|
||||
return ShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
|
||||
|
||||
class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
"""
|
||||
@@ -116,74 +161,68 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
|
||||
def test_sharegpt_llama(self):
|
||||
"Make sure the sharegpt/llama is tokenized and formatted correctly."
|
||||
prompter = ShareGPTPrompterV2(conversation="llama-2")
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
strat = prompt_strat("llama-2", self.tokenizer)
|
||||
|
||||
def tokenize(conv):
|
||||
return strat.tokenize_prompt(conv)["input_ids"]
|
||||
return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
|
||||
|
||||
def decode(ids):
|
||||
return strat.tokenizer.decode(ids)
|
||||
|
||||
# Multi-turn conversations
|
||||
multi_turn_conv = {
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "human", "value": "abc"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "123"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
]
|
||||
}
|
||||
# fmt: off
|
||||
mt_ids = tokenize(multi_turn_conv)
|
||||
# System message, multi-turn conversations
|
||||
mt_ids = tokenize(test_data['multi_turn_sys'])
|
||||
assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
||||
assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||
|
||||
# Single-turn conversations
|
||||
single_turn_conv = {
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "human", "value": "abc"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
]
|
||||
}
|
||||
|
||||
st_ids = tokenize(single_turn_conv)
|
||||
# System message, single-turn conversations
|
||||
st_ids = tokenize(test_data['single_turn_sys'])
|
||||
assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
|
||||
assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
||||
|
||||
# No system message, single-turn
|
||||
no_sys_conv = {
|
||||
"conversations": [
|
||||
{"from": "human", "value": "abc"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
]
|
||||
}
|
||||
|
||||
ns_ids = tokenize(no_sys_conv)
|
||||
ns_ids = tokenize(test_data['single_turn_no_sys'])
|
||||
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
|
||||
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
|
||||
|
||||
# No system message, multi-turn
|
||||
no_sys_mt_conv = {
|
||||
"conversations": [
|
||||
{"from": "human", "value": "abc"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "123"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
]
|
||||
}
|
||||
ns_mt_ids = tokenize(no_sys_mt_conv)
|
||||
ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
|
||||
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
||||
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||
# fmt: on
|
||||
|
||||
def test_sharegpt_mistral(self):
|
||||
"Make sure the sharegpt/mistral is tokenized and formatted correctly."
|
||||
strat = prompt_strat("mistral", self.tokenizer)
|
||||
|
||||
def tokenize(conv):
|
||||
return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
|
||||
|
||||
def decode(ids):
|
||||
return strat.tokenizer.decode(ids)
|
||||
|
||||
# fmt: off
|
||||
# System message, multi-turn conversations
|
||||
mt_ids = tokenize(test_data['multi_turn_sys'])
|
||||
assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
|
||||
assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||
|
||||
# System message, single-turn conversations
|
||||
st_ids = tokenize(test_data['single_turn_sys'])
|
||||
assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
|
||||
assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
||||
|
||||
# No system message, single-turn
|
||||
ns_ids = tokenize(test_data['single_turn_no_sys'])
|
||||
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
|
||||
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
|
||||
|
||||
# No system message, multi-turn
|
||||
ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
|
||||
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
|
||||
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||
# fmt: on
|
||||
|
||||
def test_sharegpt_changes_roles(self):
|
||||
conversation = {
|
||||
"roles": ["USER", "CHARACTER"],
|
||||
|
||||
Reference in New Issue
Block a user