Add Glaive conversation format support (#1365)
* Add Glaive conversation format support * fix black formatting errors * Fix black and pylint formatting errors * only set role_key_tool if provided in the dataset constructor * Update src/axolotl/prompt_strategies/sharegpt.py Co-authored-by: Wing Lian <wing.lian@gmail.com> * sharegpt test * tokenizer test * fix formatting --------- Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Test module for sharegpt integration w chatml
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from tokenizers import AddedToken
|
||||
@@ -8,6 +9,7 @@ from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies.sharegpt import (
|
||||
GlaiveShareGPTPromptTokenizingStrategy,
|
||||
SimpleShareGPTPromptTokenizingStrategy,
|
||||
register_chatml_template,
|
||||
)
|
||||
@@ -48,6 +50,18 @@ def fixture_sharegpt_dataset():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="glaive_dataset")
|
||||
def fixture_sharegpt_glaive_dataset():
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"system": "SYSTEM: This is a system prompt",
|
||||
"chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="tokenizer")
|
||||
def fixture_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
||||
@@ -156,3 +170,29 @@ class TestSharegpt:
|
||||
32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_chatml_glaive(self, glaive_dataset, tokenizer):
|
||||
strategy = GlaiveShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="chatml",
|
||||
role_key_model=None,
|
||||
role_key_human=None,
|
||||
),
|
||||
tokenizer,
|
||||
True, # train_on_inputs
|
||||
2048, # sequence_len
|
||||
)
|
||||
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
strategy, glaive_dataset, process_count=1
|
||||
)
|
||||
|
||||
labels = dataset_wrapper[0]["labels"]
|
||||
# fmt: off
|
||||
assert labels == [
|
||||
1, # bos
|
||||
32001, 1587, 13, 3260, 349, 264, 1587, 11510, 32000, 28705, 13, # system
|
||||
32001, 2188, 13, 6325, 368, 1820, 264, 9314, 354, 528, 477, 1450, 2726, 298, 4222, 28804, 32000, 28705, 13, # human
|
||||
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Module for testing prompt tokenizers."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import unittest
|
||||
@@ -18,6 +19,7 @@ from axolotl.prompt_strategies.llama2_chat import (
|
||||
Llama2ChatPrompter,
|
||||
LLama2ChatTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
@@ -266,6 +268,23 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
idx = res["input_ids"].index(20255) # assistant token
|
||||
assert res["labels"][idx] == -100
|
||||
|
||||
def test_glaive_tool_label_ignore(self):
|
||||
conversation = {
|
||||
"system": "SYSTEM: This is a system prompt",
|
||||
"chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
|
||||
}
|
||||
prompter = ShareGPTPrompterV2()
|
||||
strat = GlaiveShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
res = strat.tokenize_prompt(conversation)
|
||||
idx = res["input_ids"].index(13566) # assistant token
|
||||
assert res["labels"][idx] == -100
|
||||
|
||||
def test_no_sys_prompt(self):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
|
||||
Reference in New Issue
Block a user