diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py
index ae2f56fa1..0f8c31d6a 100644
--- a/src/axolotl/prompt_strategies/alpaca_chat.py
+++ b/src/axolotl/prompt_strategies/alpaca_chat.py
@@ -49,7 +49,7 @@ class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
         return (
             prompt["message_1"],
             "",
-            prompt["message_1"],
+            prompt["message_2"],
         )
 
 
diff --git a/src/axolotl/prompt_strategies/sharegpt_simple.py b/src/axolotl/prompt_strategies/sharegpt_simple.py
new file mode 100644
index 000000000..4346663f2
--- /dev/null
+++ b/src/axolotl/prompt_strategies/sharegpt_simple.py
@@ -0,0 +1,46 @@
+"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
+
+from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
+from axolotl.prompters import PromptStyle, ShareGPTPrompter
+
+
+def load(tokenizer, cfg):
+    return SimpleShareGPTPromptTokenizingStrategy(
+        ShareGPTPrompter(PromptStyle.CHAT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+def load_guanaco(tokenizer, cfg):
+    return GuanacoShareGPTPromptTokenizingStrategy(
+        ShareGPTPrompter(PromptStyle.CHAT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+    """
+    basic sharegpt strategy to grab conversations from the sample row
+    """
+
+    def get_conversation_thread(self, prompt):
+        return prompt["conversations"]
+
+
+class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+    """
+    sharegpt strategy that remaps oasst data to sharegpt format
+    """
+
+    def get_conversation_thread(self, prompt):
+        conversations = prompt["conversations"]
+        # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
+        role_map = {"prompter": "human", "assistant": "gpt"}
+        turns = [
+            {"from": role_map[t["role"]], "value": t["text"]} for t in conversations
+        ]
+        return turns