diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 582c35ebd..8b3c88fee 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -371,15 +371,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): ] # not masked out from labels labels = copy.deepcopy(res["input_ids"]) + elif part[0] == "SYSTEM:": + part = part[1] # Ignore the system role from preamble + # this is only ever the first part, should include the bos token and the user query + res = self._tokenize( + part.strip(), add_eos_token=False, strip_bos_token=False + ) + # everything from this is masked out from the labels + labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) else: logging.warning(f"unhandled role: {part[0]}") - else: - # this is only ever the first part, should include the bos token and the user query - res = self._tokenize( - part.strip(), add_eos_token=False, strip_bos_token=False - ) - # everything from this is masked out from the labels - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) # pylint: disable=duplicate-code result, current_len = parse_tokenized_to_result( diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 1a2535e19..39c74023b 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -3,7 +3,7 @@ import dataclasses import logging from enum import Enum, auto -from typing import Generator, List, Optional, Union +from typing import Generator, List, Optional, Tuple, Union IGNORE_TOKEN_ID = -100 @@ -235,16 +235,16 @@ class Conversation: sep: str = "###" sep2: Optional[str] = None - def get_prompt(self) -> Generator[str, None, None]: + def get_prompt(self) -> Generator[Tuple[str, str], None, None]: # seps = [self.sep, self.sep2] preamble = self.system + self.sep - yield preamble + yield ("SYSTEM:", preamble) for _, (role, message) in enumerate(self.messages): if message: - yield role + ":" + " " + message + yield (role + ":", " " + message) else: logging.warning(f"role with empty message: {role}") - yield role + ":" + yield (role + ":", "") def copy(self): return Conversation(