From 21c8e2deabdd08408abe3d4c75cf18e00bc2f30b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 28 May 2023 14:36:33 -0400 Subject: [PATCH] refactor conversation plucking in sharegpt --- src/axolotl/prompt_tokenizers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index bfe6fc877..a91a4e2d3 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -268,6 +268,9 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy): class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): + def get_conversation_thread(self, prompt): + return prompt["conversations"] + def tokenize_prompt(self, prompt): result = { "input_ids": [], @@ -279,7 +282,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): assistant_token = self._get_assistant_token() try: for i, part in enumerate( - self.prompter.build_prompt(prompt["conversations"]) + self.prompter.build_prompt(self.get_conversation_thread(prompt)) ): if isinstance(part, tuple): if part[0] == "USER:":