Create preprocess CLI (#785)
* Create preprocess CLI * Print prompt template if debugging * Add print for unsupported prompters * Formatting * Formatting * Refactor variables * Formatting * Formatting * Formatting * Formatting
This commit is contained in:
@@ -4,10 +4,12 @@ import logging
|
||||
from enum import Enum
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
from colorama import Fore
|
||||
from fastchat.conversation import Conversation, get_conv_template
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
IGNORE_TOKEN_ID = -100
|
||||
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
|
||||
|
||||
|
||||
class PromptStyle(Enum):
|
||||
@@ -55,20 +57,15 @@ class AlpacaPrompter:
|
||||
)
|
||||
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
def _build_result(self, instruction, input_text, output):
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
if input_text:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_prompt)
|
||||
if self.system_prompt
|
||||
else ""
|
||||
) + self.turn_format.format(instruction=instruction, input=input)
|
||||
) + self.turn_format.format(instruction=instruction, input=input_text)
|
||||
else:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_no_input_prompt)
|
||||
@@ -77,7 +74,21 @@ class AlpacaPrompter:
|
||||
) + self.turn_no_input_format.format(instruction=instruction)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
yield res
|
||||
|
||||
return res
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
yield self._build_result(instruction, input, output)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return REPR_TEMPLATE.format(
|
||||
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
|
||||
)
|
||||
|
||||
|
||||
class UnpromptedPrompter(AlpacaPrompter):
|
||||
@@ -191,14 +202,14 @@ class ReflectAlpacaPrompter:
|
||||
)
|
||||
self.response_split = "ASSISTANT:"
|
||||
|
||||
def build_prompt(
|
||||
def _build_result(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
reflection: Union[None, str] = None,
|
||||
corrected: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
):
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
@@ -212,7 +223,30 @@ class ReflectAlpacaPrompter:
|
||||
corrected=corrected,
|
||||
)
|
||||
res = f"{res}{label}"
|
||||
yield res
|
||||
|
||||
return res
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
reflection: Union[None, str] = None,
|
||||
corrected: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
# pylint: disable=duplicate-code
|
||||
yield self._build_result(
|
||||
instruction,
|
||||
input,
|
||||
output,
|
||||
reflection,
|
||||
corrected,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return REPR_TEMPLATE.format(
|
||||
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
|
||||
)
|
||||
|
||||
|
||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
@@ -247,7 +281,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
if role_key_model:
|
||||
self.role_key_model = role_key_model
|
||||
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
def _build_result(self, source):
|
||||
if len(source) < 2:
|
||||
# If there isn't a back and forth conversation, ignore it
|
||||
# also happens on the data splitting leaving empty conversations
|
||||
@@ -282,11 +316,20 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||
conv.append_message(role, sentence["value"])
|
||||
|
||||
for part in conv.get_turns():
|
||||
return conv.get_turns()
|
||||
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
turns = self._build_result(source)
|
||||
|
||||
for part in turns:
|
||||
if part[0] and not part[1]:
|
||||
LOG.warning(f"role with empty message: {part[0]}")
|
||||
yield part
|
||||
|
||||
def __repr__(self) -> str:
|
||||
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
|
||||
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
|
||||
|
||||
|
||||
class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||
"""
|
||||
@@ -304,3 +347,15 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||
role_key_human=role_key_human,
|
||||
role_key_model=role_key_model,
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedPrompter:
|
||||
"""
|
||||
A dummy class for custom prompters
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return "Pre-tokenized or custom dataset types are unsupported for logging"
|
||||
|
||||
Reference in New Issue
Block a user