Merge pull request #101 from OpenAccess-AI-Collective/sharegpt-conv

refactor conversation plucking in sharegpt
This commit is contained in:
Wing Lian
2023-05-28 19:43:54 -04:00
committed by GitHub

View File

@@ -268,6 +268,9 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
def get_conversation_thread(self, prompt):
return prompt["conversations"]
def tokenize_prompt(self, prompt):
result = {
"input_ids": [],
@@ -279,7 +282,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
assistant_token = self._get_assistant_token()
try:
for i, part in enumerate(
self.prompter.build_prompt(prompt["conversations"])
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if isinstance(part, tuple):
if part[0] == "USER:":