Add Glaive conversation format support (#1365)

* Add Glaive conversation format support

* fix black formatting errors

* Fix black and pylint formatting errors

* only set role_key_tool if provided in the dataset constructor

* Update src/axolotl/prompt_strategies/sharegpt.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* sharegpt test

* tokenizer test

* fix formatting

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Brian Fitzgerald
2024-03-10 19:50:25 -05:00
committed by GitHub
parent b0ee9ec734
commit b7d8a7dc4d
6 changed files with 184 additions and 3 deletions

View File

@@ -1,6 +1,7 @@
"""
Test module for sharegpt integration w chatml
"""
import pytest
from datasets import Dataset
from tokenizers import AddedToken
@@ -8,6 +9,7 @@ from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.sharegpt import (
GlaiveShareGPTPromptTokenizingStrategy,
SimpleShareGPTPromptTokenizingStrategy,
register_chatml_template,
)
@@ -48,6 +50,18 @@ def fixture_sharegpt_dataset():
)
@pytest.fixture(name="glaive_dataset")
def fixture_sharegpt_glaive_dataset():
return Dataset.from_list(
[
{
"system": "SYSTEM: This is a system prompt",
"chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
}
]
)
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
@@ -156,3 +170,29 @@ class TestSharegpt:
32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt
]
# fmt: on
def test_chatml_glaive(self, glaive_dataset, tokenizer):
strategy = GlaiveShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation="chatml",
role_key_model=None,
role_key_human=None,
),
tokenizer,
True, # train_on_inputs
2048, # sequence_len
)
dataset_wrapper = TokenizedPromptDataset(
strategy, glaive_dataset, process_count=1
)
labels = dataset_wrapper[0]["labels"]
# fmt: off
assert labels == [
1, # bos
32001, 1587, 13, 3260, 349, 264, 1587, 11510, 32000, 28705, 13, # system
32001, 2188, 13, 6325, 368, 1820, 264, 9314, 354, 528, 477, 1450, 2726, 298, 4222, 28804, 32000, 28705, 13, # human
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
]
# fmt: on

View File

@@ -1,4 +1,5 @@
"""Module for testing prompt tokenizers."""
import json
import logging
import unittest
@@ -18,6 +19,7 @@ from axolotl.prompt_strategies.llama2_chat import (
Llama2ChatPrompter,
LLama2ChatTokenizingStrategy,
)
from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
@@ -266,6 +268,23 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
idx = res["input_ids"].index(20255) # assistant token
assert res["labels"][idx] == -100
def test_glaive_tool_label_ignore(self):
conversation = {
"system": "SYSTEM: This is a system prompt",
"chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
}
prompter = ShareGPTPrompterV2()
strat = GlaiveShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
with self._caplog.at_level(logging.WARNING):
res = strat.tokenize_prompt(conversation)
idx = res["input_ids"].index(13566) # assistant token
assert res["labels"][idx] == -100
def test_no_sys_prompt(self):
"""
tests the interface between the user and assistant parts