Fix sharegpt prompt

This commit is contained in:
NanoCode012
2023-05-31 00:38:08 +09:00
parent cfcc549f6b
commit 25eeeeba0b
2 changed files with 13 additions and 12 deletions

View File

@@ -371,15 +371,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
]
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
elif part[0] == "SYSTEM:":
part = part[1] # Ignore the system role from preamble
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
part.strip(), add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
logging.warning(f"unhandled role: {part[0]}")
else:
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
part.strip(), add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(

View File

@@ -3,7 +3,7 @@
import dataclasses
import logging
from enum import Enum, auto
from typing import Generator, List, Optional, Union
from typing import Generator, List, Optional, Tuple, Union
IGNORE_TOKEN_ID = -100
@@ -235,16 +235,16 @@ class Conversation:
sep: str = "###"
sep2: Optional[str] = None
def get_prompt(self) -> Generator[str, None, None]:
def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
# seps = [self.sep, self.sep2]
preamble = self.system + self.sep
yield preamble
yield ("SYSTEM:", preamble)
for _, (role, message) in enumerate(self.messages):
if message:
yield role + ":" + " " + message
yield (role + ":", " " + message)
else:
logging.warning(f"role with empty message: {role}")
yield role + ":"
yield (role + ":", "")
def copy(self):
return Conversation(