diff --git a/README.md b/README.md index 7ebea8678..214bfd14d 100644 --- a/README.md +++ b/README.md @@ -219,6 +219,14 @@ Have dataset(s) in one of the following format (JSONL recommended): ```json {"conversations": [{"role": "...", "value": "..."}]} ``` +- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from` + ```json + {"conversations": [{"role": "...", "value": "..."}]} + ``` +- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny + ```json + {"conversations": [{"title": "...", "text": "...", "explanation": "..."}]} + ``` @@ -530,7 +538,7 @@ Try set `fp16: true` Try to turn off xformers. -## Need help? 🙋‍♂️ +## Need help? 🙋♂️ Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index d6367ce7c..40c58bc9c 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -33,12 +33,16 @@ class TokenizedPromptDataset(IterableDataset): def __iter__(self): iterator = iter(self.dataset) + count = 0 # Loop through the entire dataset for example in iterator: try: yield self.prompt_tokenizer.tokenize_prompt(example) + count += 1 except InvalidDataException: pass + if count == 0: + raise RuntimeError("Expected at least one datapoint in dataset.") # TODO this isn't the best since it can't interleave datasets diff --git a/src/axolotl/prompt_strategies/sharegpt_jokes.py b/src/axolotl/prompt_strategies/sharegpt_jokes.py new file mode 100644 index 000000000..ac424bf7c --- /dev/null +++ b/src/axolotl/prompt_strategies/sharegpt_jokes.py @@ -0,0 +1,28 @@ +"""Module for Jokes prompts using sharegpt style """ +from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy +from axolotl.prompters import PromptStyle, ShareGPTPrompter + + +def load(tokenizer, cfg): + return SimpleJokesShareGPTPromptTokenizingStrategy( + ShareGPTPrompter(PromptStyle.CHAT.value), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + + +class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): + """ + Tokenization strategy for asking bot to tell a joke and then explain why its funny + """ + + # title, text, explanation + def get_conversation_thread(self, prompt): + title = "" if not prompt["title"] else prompt["title"] + " " + return [ + {"from": "human", "value": "Tell me a joke."}, + {"from": "gpt", "value": title + prompt["text"]}, + {"from": "human", "value": "Why is that joke funny?"}, + {"from": "gpt", "value": prompt["explanation"]}, + ] diff --git a/src/axolotl/prompt_strategies/sharegpt_simple.py b/src/axolotl/prompt_strategies/sharegpt_simple.py index 4346663f2..bfe0d164b 100644 --- a/src/axolotl/prompt_strategies/sharegpt_simple.py +++ b/src/axolotl/prompt_strategies/sharegpt_simple.py @@ -13,6 +13,15 @@ def load(tokenizer, cfg): ) +def load_role(tokenizer, cfg): + return SimpleRoleShareGPTPromptTokenizingStrategy( + ShareGPTPrompter(PromptStyle.CHAT.value), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + + def load_guanaco(tokenizer, cfg): return GuanacoShareGPTPromptTokenizingStrategy( ShareGPTPrompter(PromptStyle.CHAT.value), @@ -31,6 +40,18 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): return prompt["conversations"] +class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): + """ + basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from + """ + + def get_conversation_thread(self, prompt): + conversations = prompt["conversations"] + # remap role: prompter/assistant, text: ... => from: human/gpt, value: ... + turns = [{"from": t["role"], "value": t["value"]} for t in conversations] + return turns + + class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): """ sharegpt strategy that remaps oasst data to sharegpt format diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 39c74023b..29cc4446b 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -261,28 +261,33 @@ class Conversation: self.messages.append([role, message]) -conv_vicuna_v1_1 = Conversation( - system="A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions.", - roles=["USER", "ASSISTANT"], - messages=[], - offset=0, - sep_style=SeparatorStyle.TWO, - sep=" ", - sep2=" ", -) - - class ShareGPTPrompter: # pylint: disable=too-few-public-methods """ A prompter that generates prompts for the ShareGPT """ - def __init__(self, prompt_style=None): + def __init__(self, prompt_style=None, system_prompt: Optional[str] = None): if prompt_style != PromptStyle.CHAT.value: raise ValueError( f"unsupported prompt_style for ShareGPTPrompter({prompt_style})" ) + system: str = ( + system_prompt + if system_prompt + else ( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ) + ) + self._conversation = Conversation( + system=system, + roles=["USER", "ASSISTANT"], + messages=[], + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2=" ", + ) # def match_prompt_style(self): # if self.prompt_style == PromptStyle.chat.value: @@ -300,7 +305,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods # also happens on the data splitting leaving empty conversations raise IndexError - conv = conv_vicuna_v1_1.copy() + conv = self._conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} try: diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index cba964076..9fee2fb9b 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -239,8 +239,15 @@ def load_tokenized_prepared_datasets( ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) datasets.append(ds_wrapper) else: - logging.error(f"unhandled prompt tokenization strategy: {d.type}") - raise ValueError(f"unhandled prompt tokenization strategy: {d.type}") + suffix = "" + if ":load_" in d.type: + suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?" + logging.error( + f"unhandled prompt tokenization strategy: {d.type}. {suffix}" + ) + raise ValueError( + f"unhandled prompt tokenization strategy: {d.type} {suffix}" + ) logging.info("tokenizing, merging, and shuffling master dataset") samples: List[int] = []