Fix sharegpt prompt
This commit is contained in:
@@ -371,15 +371,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
]
|
]
|
||||||
# not masked out from labels
|
# not masked out from labels
|
||||||
labels = copy.deepcopy(res["input_ids"])
|
labels = copy.deepcopy(res["input_ids"])
|
||||||
|
elif part[0] == "SYSTEM:":
|
||||||
|
part = part[1] # Ignore the system role from preamble
|
||||||
|
# this is only ever the first part, should include the bos token and the user query
|
||||||
|
res = self._tokenize(
|
||||||
|
part.strip(), add_eos_token=False, strip_bos_token=False
|
||||||
|
)
|
||||||
|
# everything from this is masked out from the labels
|
||||||
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
else:
|
else:
|
||||||
logging.warning(f"unhandled role: {part[0]}")
|
logging.warning(f"unhandled role: {part[0]}")
|
||||||
else:
|
|
||||||
# this is only ever the first part, should include the bos token and the user query
|
|
||||||
res = self._tokenize(
|
|
||||||
part.strip(), add_eos_token=False, strip_bos_token=False
|
|
||||||
)
|
|
||||||
# everything from this is masked out from the labels
|
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
result, current_len = parse_tokenized_to_result(
|
result, current_len = parse_tokenized_to_result(
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import Generator, List, Optional, Union
|
from typing import Generator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
IGNORE_TOKEN_ID = -100
|
IGNORE_TOKEN_ID = -100
|
||||||
|
|
||||||
@@ -235,16 +235,16 @@ class Conversation:
|
|||||||
sep: str = "###"
|
sep: str = "###"
|
||||||
sep2: Optional[str] = None
|
sep2: Optional[str] = None
|
||||||
|
|
||||||
def get_prompt(self) -> Generator[str, None, None]:
|
def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
|
||||||
# seps = [self.sep, self.sep2]
|
# seps = [self.sep, self.sep2]
|
||||||
preamble = self.system + self.sep
|
preamble = self.system + self.sep
|
||||||
yield preamble
|
yield ("SYSTEM:", preamble)
|
||||||
for _, (role, message) in enumerate(self.messages):
|
for _, (role, message) in enumerate(self.messages):
|
||||||
if message:
|
if message:
|
||||||
yield role + ":" + " " + message
|
yield (role + ":", " " + message)
|
||||||
else:
|
else:
|
||||||
logging.warning(f"role with empty message: {role}")
|
logging.warning(f"role with empty message: {role}")
|
||||||
yield role + ":"
|
yield (role + ":", "")
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
return Conversation(
|
return Conversation(
|
||||||
|
|||||||
Reference in New Issue
Block a user