use fastchat conversations template (#578)
* use fastchat conversations template * require fastchat (fschat) pip install * handle roles dynamically from conversation * tweak fastchat conversation with a monkeypatch to get individual turns * fix up so it works with multiple conversation styles, and don't strip the turns * fix sharegpt fixture now that we're using a more correct tokenization * use a new prompter and support fastchat conversation type * use sharegpt from prompt strategies now * update docs, add chatml template * add a newline after im_end token * ensure we correctly set system message * update per PR feedback to handle deprecated sharegpt types * don't add duplicate wandb req * make sharegpt fields configurable from yml * llama2 fixes * don't fail fatally when turns are improper
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
"""Module containing prompters"""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from enum import Enum, auto
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
from enum import Enum
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
from fastchat.conversation import Conversation, get_conv_template
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
IGNORE_TOKEN_ID = -100
|
||||
@@ -214,53 +215,6 @@ class ReflectAlpacaPrompter:
|
||||
yield res
|
||||
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
"""Different separator style."""
|
||||
|
||||
SINGLE = auto()
|
||||
TWO = auto()
|
||||
DOLLY = auto()
|
||||
|
||||
|
||||
# TODO clean this 💩 up
|
||||
@dataclasses.dataclass
|
||||
class Conversation:
|
||||
"""A class that keeps all conversation history."""
|
||||
|
||||
system: str
|
||||
roles: List[str]
|
||||
messages: List[List[str]]
|
||||
offset: int
|
||||
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
||||
sep: str = "###"
|
||||
sep2: Optional[str] = None
|
||||
|
||||
def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
|
||||
# seps = [self.sep, self.sep2]
|
||||
preamble = self.system + self.sep
|
||||
yield ("SYSTEM:", preamble)
|
||||
for _, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield (role + ":", " " + message)
|
||||
else:
|
||||
LOG.warning(f"role with empty message: {role}")
|
||||
yield (role + ":", "")
|
||||
|
||||
def copy(self):
|
||||
return Conversation(
|
||||
system=self.system,
|
||||
roles=self.roles,
|
||||
messages=[[x, y] for x, y in self.messages],
|
||||
offset=self.offset,
|
||||
sep_style=self.sep_style,
|
||||
sep=self.sep,
|
||||
sep2=self.sep2,
|
||||
)
|
||||
|
||||
def append_message(self, role, message):
|
||||
self.messages.append([role, message])
|
||||
|
||||
|
||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
"Role did not alternate between turns (gpt and human). Please check your data."
|
||||
)
|
||||
@@ -271,28 +225,27 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
A prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
|
||||
def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
|
||||
if prompt_style != PromptStyle.CHAT.value:
|
||||
raise ValueError(
|
||||
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
|
||||
)
|
||||
system: str = (
|
||||
system_prompt
|
||||
if system_prompt
|
||||
else (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
)
|
||||
)
|
||||
self._conversation = Conversation(
|
||||
system=system,
|
||||
roles=["USER", "ASSISTANT"],
|
||||
messages=[],
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.TWO,
|
||||
sep=" ",
|
||||
sep2=" ",
|
||||
)
|
||||
role_key_human = "human"
|
||||
role_key_model = "gpt"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_style=None, # pylint: disable=unused-argument
|
||||
conversation: Optional[Union[str, Conversation]] = None,
|
||||
role_key_human: Optional[str] = None,
|
||||
role_key_model: Optional[str] = None,
|
||||
):
|
||||
if conversation:
|
||||
if isinstance(conversation, Conversation):
|
||||
self._conversation = conversation
|
||||
else:
|
||||
self._conversation = get_conv_template(conversation)
|
||||
else:
|
||||
self._conversation = get_conv_template("vicuna_v1.1")
|
||||
if role_key_human:
|
||||
self.role_key_human = role_key_human
|
||||
if role_key_model:
|
||||
self.role_key_model = role_key_model
|
||||
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
if len(source) < 2:
|
||||
@@ -306,17 +259,14 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
|
||||
# Add the conversation system prompt if provided, otherwise use the default one
|
||||
if source[0]["from"] == "system":
|
||||
conv.system = source[0]["value"]
|
||||
conv.set_system_message(source[0]["value"])
|
||||
source.pop(0)
|
||||
|
||||
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
||||
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
|
||||
|
||||
try:
|
||||
# Apply prompt templates
|
||||
if (
|
||||
source[0]["from"] not in roles
|
||||
or roles[source[0]["from"]] != conv.roles[0]
|
||||
):
|
||||
if source[0]["from"] not in roles:
|
||||
# Skip the first one if it is not from human
|
||||
source = source[1:]
|
||||
except IndexError as err:
|
||||
@@ -326,8 +276,29 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
conv.messages = []
|
||||
for j, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
if role != conv.roles[j % 2]:
|
||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||
conv.append_message(role, sentence["value"])
|
||||
|
||||
for part in conv.get_prompt():
|
||||
for part in conv.get_turns():
|
||||
if part[0] and not part[1]:
|
||||
LOG.warning(f"role with empty message: {part[0]}")
|
||||
yield part
|
||||
|
||||
|
||||
class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||
"""
|
||||
A V2 prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conversation: Optional[Union[str, Conversation]] = None,
|
||||
role_key_human: Optional[str] = None,
|
||||
role_key_model: Optional[str] = None,
|
||||
):
|
||||
super().__init__(
|
||||
conversation=conversation,
|
||||
role_key_human=role_key_human,
|
||||
role_key_model=role_key_model,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user