Feat: Add sharegpt multirole (#1137)
* feat(prompt): support multiple roles for sharegpt * fix: add handling of empty role back * feat: rebased and allowed more dynamic roles via config * fix: variable * chore: update message * feat: add vicuna format * fix: JSON serializable error * fix: typing * fix: don't remap for unknown keys * fix: add roles to pydantic * feat: add test * chore: remove leftover print * chore: remove leftover comment * chore: remove print * fix: update test to use chatml
This commit is contained in:
@@ -62,6 +62,38 @@ def fixture_sharegpt_glaive_dataset():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="multi_role_dataset")
|
||||
def fixture_multi_role_dataset():
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"conversations": [
|
||||
{
|
||||
"from": "system",
|
||||
"value": "use get_weather(city) to get the weather for a city",
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "hello, what's the weather in New York?",
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "let me get that for you",
|
||||
},
|
||||
{
|
||||
"from": "tool",
|
||||
"value": "get_weather(New York)",
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "the weather in New York is 70 degrees and sunny",
|
||||
},
|
||||
]
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="tokenizer")
|
||||
def fixture_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
||||
@@ -196,3 +228,39 @@ class TestSharegpt:
|
||||
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_multi_role_dataset(self, multi_role_dataset, tokenizer):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]}),
|
||||
tokenizer,
|
||||
False, # train_on_inputs
|
||||
2048, # sequence_len
|
||||
)
|
||||
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
strategy, multi_role_dataset, process_count=1
|
||||
)
|
||||
|
||||
input_ids = dataset_wrapper[0]["input_ids"]
|
||||
# fmt: off
|
||||
assert input_ids == [
|
||||
1, # bos
|
||||
32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13, # system
|
||||
32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13, # human
|
||||
32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt
|
||||
32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13, # tool
|
||||
32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
labels = dataset_wrapper[0]["labels"]
|
||||
# fmt: off
|
||||
assert labels == [
|
||||
-100, # bos
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # system
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # human
|
||||
-100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool
|
||||
-100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
Reference in New Issue
Block a user