def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """
    Build the tokenizing strategy used for Glaive-style datasets.

    The conversation template name is taken from ``ds_cfg["conversation"]``
    when the dataset config supplies one; otherwise ``"chatml_glaive"`` is used.
    ``cfg.train_on_inputs`` and ``cfg.sequence_len`` are forwarded unchanged.
    """
    template_name = "chatml_glaive"
    if ds_cfg and "conversation" in ds_cfg:
        template_name = ds_cfg["conversation"]

    prompter = ShareGPTPrompterV2(conversation=template_name)
    return GlaiveShareGPTPromptTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )


class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
    """
    sharegpt strategy that remaps glaive data to sharegpt format
    """

    def get_conversation_thread(self, prompt):
        # Convert the raw row into sharegpt-style turns first, then collapse
        # runs of turns that come from the same sender.
        return merge_consecutive_messages(chatml_to_conversation(prompt))
# Role markers that can open a message inside a Glaive "chat" string.
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
# Glaive role marker -> ShareGPT "from" value.
GLAIVE_TO_SHAREGPT_ROLE = {
    "SYSTEM": "system",
    "USER": "human",
    "ASSISTANT": "gpt",
    "FUNCTION RESPONSE": "tool",
}

# Matches "<ROLE>: "; the capturing group makes re.split keep the role names.
GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ")


def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]:
    """
    Converts a Glaive formatted row to a list of messages in ShareGPT format.

    Initially based off https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb.

    Args:
        row: mapping with a required ``"chat"`` string made of ``"USER: "``,
            ``"ASSISTANT: "`` and ``"FUNCTION RESPONSE: "`` delimited turns, and
            an optional ``"system"`` string (a leading ``"SYSTEM: "`` is removed).

    Returns:
        A list of ``{"from": ..., "value": ...}`` dicts, with the system message
        (when non-empty) first. Turn text is stripped of surrounding whitespace;
        turns that are empty after stripping are dropped.

    Raises:
        KeyError: if ``row`` has no ``"chat"`` entry.
    """

    system_prompt = row.get("system")
    if system_prompt:
        system_prompt = system_prompt.removeprefix("SYSTEM: ")

    # With one capturing group, re.split always returns
    # [preamble, role, text, role, text, ...]. Pairing by position (instead of
    # filtering empty strings first) keeps roles and texts aligned even when
    # the chat starts with stray text/whitespace or contains an empty turn.
    parts = GLAIVE_MSG_REGEX.split(row["chat"])

    chat_msg_dicts = []
    if system_prompt:
        chat_msg_dicts.append(
            {"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt}
        )

    for role, value in zip(parts[1::2], parts[2::2]):
        value = value.strip()
        if not value:
            # An empty turn carries no content; skip it rather than emit "".
            continue
        chat_msg_dicts.append({"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value})

    return chat_msg_dicts


def merge_consecutive_messages(messages, *, separator=""):
    """
    Merge consecutive messages from the same sender into a single message.
    This can be useful with datasets that contain multiple consecutive tool calls.

    Args:
        messages: iterable of dicts with ``"from"`` and ``"value"`` keys.
        separator: text inserted between merged values. Defaults to ``""``,
            i.e. plain concatenation.

    Returns:
        A new list of ``{"from", "value"}`` dicts; the input dicts are not modified.
    """

    merged_messages = []
    for msg in messages:
        if merged_messages and merged_messages[-1]["from"] == msg["from"]:
            merged_messages[-1]["value"] += separator + msg["value"]
        else:
            # Copy into a fresh dict so later merging never mutates the caller's data.
            merged_messages.append({"from": msg["from"], "value": msg["value"]})

    return merged_messages
def test_chatml_glaive(self, glaive_dataset, tokenizer):
    """
    Tokenize the glaive fixture row with the "chatml" conversation template
    and train_on_inputs enabled, then compare the labels to the expected ids.
    """
    prompter = ShareGPTPrompterV2(
        conversation="chatml",
        role_key_model=None,
        role_key_human=None,
    )
    strategy = GlaiveShareGPTPromptTokenizingStrategy(
        prompter,
        tokenizer,
        True,  # train_on_inputs
        2048,  # sequence_len
    )
    tokenized = TokenizedPromptDataset(strategy, glaive_dataset, process_count=1)

    # fmt: off
    expected_labels = [
        1,  # bos
        32001, 1587, 13, 3260, 349, 264, 1587, 11510, 32000, 28705, 13,  # system
        32001, 2188, 13, 6325, 368, 1820, 264, 9314, 354, 528, 477, 1450, 2726, 298, 4222, 28804, 32000, 28705, 13,  # human
        32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13  # gpt
    ]
    # fmt: on
    assert tokenized[0]["labels"] == expected_labels
def test_glaive_tool_label_ignore(self):
    """
    Tokenize a glaive row with train_on_inputs disabled and check that the
    label at the assistant token position is masked (-100).

    NOTE(review): the sample row contains only USER and ASSISTANT turns, no
    FUNCTION RESPONSE turn.
    """
    sample = {
        "system": "SYSTEM: This is a system prompt",
        "chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights.  <|endoftext|>",
    }
    strategy = GlaiveShareGPTPromptTokenizingStrategy(
        ShareGPTPrompterV2(),
        self.tokenizer,
        False,
        2048,
    )
    with self._caplog.at_level(logging.WARNING):
        tokenized = strategy.tokenize_prompt(sample)
        assistant_idx = tokenized["input_ids"].index(13566)  # assistant token
        assert tokenized["labels"][assistant_idx] == -100