Add Glaive conversation format support (#1365)

* Add Glaive conversation format support

* fix black formatting errors

* Fix black and pylint formatting errors

* only set role_key_tool if provided in the dataset constructor

* Update src/axolotl/prompt_strategies/sharegpt.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* sharegpt test

* tokenizer test

* fix formatting

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Brian Fitzgerald
2024-03-10 19:50:25 -05:00
committed by GitHub
parent b0ee9ec734
commit b7d8a7dc4d
6 changed files with 184 additions and 3 deletions

View File

@@ -1,4 +1,5 @@
"""Module for testing prompt tokenizers."""
import json
import logging
import unittest
@@ -18,6 +19,7 @@ from axolotl.prompt_strategies.llama2_chat import (
Llama2ChatPrompter,
LLama2ChatTokenizingStrategy,
)
from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
@@ -266,6 +268,23 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
idx = res["input_ids"].index(20255) # assistant token
assert res["labels"][idx] == -100
def test_glaive_tool_label_ignore(self):
conversation = {
"system": "SYSTEM: This is a system prompt",
"chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
}
prompter = ShareGPTPrompterV2()
strat = GlaiveShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
with self._caplog.at_level(logging.WARNING):
res = strat.tokenize_prompt(conversation)
idx = res["input_ids"].index(13566) # assistant token
assert res["labels"][idx] == -100
def test_no_sys_prompt(self):
"""
tests the interface between the user and assistant parts