From 59bb2197ed4b0438745a72103e808cd2dea697fe Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 7 Jun 2023 09:51:29 -0400
Subject: [PATCH] fix camel ai, add guanaco/oasst mapping for sharegpt

---
 src/axolotl/prompt_strategies/alpaca_chat.py  |  2 +-
 .../prompt_strategies/sharegpt_simple.py      | 46 +++++++++++++++++++
 2 files changed, 47 insertions(+), 1 deletion(-)
 create mode 100644 src/axolotl/prompt_strategies/sharegpt_simple.py

diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py
index ae2f56fa1..0f8c31d6a 100644
--- a/src/axolotl/prompt_strategies/alpaca_chat.py
+++ b/src/axolotl/prompt_strategies/alpaca_chat.py
@@ -49,7 +49,7 @@ class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
         return (
             prompt["message_1"],
             "",
-            prompt["message_1"],
+            prompt["message_2"],
         )
 
 
diff --git a/src/axolotl/prompt_strategies/sharegpt_simple.py b/src/axolotl/prompt_strategies/sharegpt_simple.py
new file mode 100644
index 000000000..4346663f2
--- /dev/null
+++ b/src/axolotl/prompt_strategies/sharegpt_simple.py
@@ -0,0 +1,46 @@
+"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
+
+from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
+from axolotl.prompters import PromptStyle, ShareGPTPrompter
+
+
+def load(tokenizer, cfg):
+    return SimpleShareGPTPromptTokenizingStrategy(
+        ShareGPTPrompter(PromptStyle.CHAT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+def load_guanaco(tokenizer, cfg):
+    return GuanacoShareGPTPromptTokenizingStrategy(
+        ShareGPTPrompter(PromptStyle.CHAT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+    """
+    basic sharegpt strategy to grab conversations from the sample row
+    """
+
+    def get_conversation_thread(self, prompt):
+        return prompt["conversations"]
+
+
+class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+    """
+    sharegpt strategy that remaps oasst data to sharegpt format
+    """
+
+    def get_conversation_thread(self, prompt):
+        conversations = prompt["conversations"]
+        # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
+        role_map = {"prompter": "human", "assistant": "gpt"}
+        turns = [
+            {"from": role_map[t["role"]], "value": t["text"]} for t in conversations
+        ]
+        return turns