Fix sharegpt prompt

This commit is contained in:
NanoCode012
2023-05-31 00:38:08 +09:00
parent cfcc549f6b
commit 25eeeeba0b
2 changed files with 13 additions and 12 deletions

View File

@@ -371,15 +371,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
] ]
# not masked out from labels # not masked out from labels
labels = copy.deepcopy(res["input_ids"]) labels = copy.deepcopy(res["input_ids"])
elif part[0] == "SYSTEM:":
part = part[1] # Ignore the system role from preamble
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
part.strip(), add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else: else:
logging.warning(f"unhandled role: {part[0]}") logging.warning(f"unhandled role: {part[0]}")
else:
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
part.strip(), add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result( result, current_len = parse_tokenized_to_result(

View File

@@ -3,7 +3,7 @@
import dataclasses import dataclasses
import logging import logging
from enum import Enum, auto from enum import Enum, auto
from typing import Generator, List, Optional, Union from typing import Generator, List, Optional, Tuple, Union
IGNORE_TOKEN_ID = -100 IGNORE_TOKEN_ID = -100
@@ -235,16 +235,16 @@ class Conversation:
sep: str = "###" sep: str = "###"
sep2: Optional[str] = None sep2: Optional[str] = None
def get_prompt(self) -> Generator[str, None, None]: def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
# seps = [self.sep, self.sep2] # seps = [self.sep, self.sep2]
preamble = self.system + self.sep preamble = self.system + self.sep
yield preamble yield ("SYSTEM:", preamble)
for _, (role, message) in enumerate(self.messages): for _, (role, message) in enumerate(self.messages):
if message: if message:
yield role + ":" + " " + message yield (role + ":", " " + message)
else: else:
logging.warning(f"role with empty message: {role}") logging.warning(f"role with empty message: {role}")
yield role + ":" yield (role + ":", "")
def copy(self): def copy(self):
return Conversation( return Conversation(