refactor conversation plucking in sharegpt

This commit is contained in:
Wing Lian
2023-05-28 14:36:33 -04:00
parent 8fe12e3bc1
commit 21c8e2deab

View File

@@ -268,6 +268,9 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
def get_conversation_thread(self, prompt):
return prompt["conversations"]
def tokenize_prompt(self, prompt):
result = {
"input_ids": [],
@@ -279,7 +282,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
assistant_token = self._get_assistant_token()
try:
for i, part in enumerate(
self.prompter.build_prompt(prompt["conversations"])
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if isinstance(part, tuple):
if part[0] == "USER:":