use fastchat conversations template (#578)
* use fastchat conversations template * require fastchat (fschat) pip install * handle roles dynamically from conversation * tweak fastchat conversation with a monkeypatch to get individual turns * fix up so it works with multiple conversation styles, and don't strip the turns * fix sharegpt fixture now that we're using a more correct tokenization * use a new prompter and support fastchat conversation type * use sharegpt from prompt strategies now * update docs, add chatml template * add a newline after im_end token * ensure we correctly set system message * update per PR feedback to handle deprecated sharegpt types * don't add duplicate wandb req * make sharegpt fields configurable from yml * llama2 fixes * don't fail fatally when turns are improper
This commit is contained in:
@@ -180,7 +180,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"instruction": "...", "input": "...", "output": "..."}
|
{"instruction": "...", "input": "...", "output": "..."}
|
||||||
```
|
```
|
||||||
- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
|
- `sharegpt`: conversations where `from` is `human`/`gpt`
|
||||||
```json
|
```json
|
||||||
{"conversations": [{"from": "...", "value": "..."}]}
|
{"conversations": [{"from": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
@@ -269,11 +269,11 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"prompt": "...", "generation": "..."}
|
{"prompt": "...", "generation": "..."}
|
||||||
```
|
```
|
||||||
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
|
- `sharegpt.load_role`: conversations where `role` is used instead of `from`
|
||||||
```json
|
```json
|
||||||
{"conversations": [{"role": "...", "value": "..."}]}
|
{"conversations": [{"role": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
|
- `sharegpt.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
|
||||||
```json
|
```json
|
||||||
{"conversations": [{"from": "...", "value": "..."}]}
|
{"conversations": [{"from": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
@@ -443,6 +443,7 @@ datasets:
|
|||||||
data_files: # Optional[str] path to source data files
|
data_files: # Optional[str] path to source data files
|
||||||
shards: # Optional[int] number of shards to split data into
|
shards: # Optional[int] number of shards to split data into
|
||||||
name: # Optional[str] name of dataset configuration to load
|
name: # Optional[str] name of dataset configuration to load
|
||||||
|
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||||
|
|
||||||
# custom user prompt
|
# custom user prompt
|
||||||
- path: repo
|
- path: repo
|
||||||
|
|||||||
@@ -31,3 +31,4 @@ scipy
|
|||||||
scikit-learn==1.2.2
|
scikit-learn==1.2.2
|
||||||
pynvml
|
pynvml
|
||||||
art
|
art
|
||||||
|
fschat==0.2.29
|
||||||
|
|||||||
174
src/axolotl/monkeypatch/fastchat_conversation_turns.py
Normal file
174
src/axolotl/monkeypatch/fastchat_conversation_turns.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
"""
|
||||||
|
monkeypatch to add a get_turns method
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Generator, Tuple
|
||||||
|
|
||||||
|
from fastchat.conversation import SeparatorStyle
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt(self) -> str:
|
||||||
|
ret = ""
|
||||||
|
for role, msg in self.get_turns():
|
||||||
|
ret += role + msg
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def get_turns( # pylint: disable=too-many-return-statements
|
||||||
|
self,
|
||||||
|
) -> Generator[Tuple[str, str], None, None]:
|
||||||
|
"""Get the prompt for generation."""
|
||||||
|
system_prompt = self.system_template.format(system_message=self.system_message)
|
||||||
|
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
|
||||||
|
yield "", system_prompt + self.sep
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
yield role + ": ", message + self.sep
|
||||||
|
else:
|
||||||
|
yield role + ":", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
|
||||||
|
seps = [self.sep, self.sep2]
|
||||||
|
yield "", system_prompt + seps[0]
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message:
|
||||||
|
yield role + ": ", message + seps[i % 2]
|
||||||
|
else:
|
||||||
|
yield role + ":", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
|
||||||
|
yield "", system_prompt + self.sep
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
yield role + ": ", message + self.sep
|
||||||
|
else:
|
||||||
|
yield role + ": ", "" # must be end with a space
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
|
||||||
|
yield "", "" if system_prompt == "" else system_prompt + self.sep
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
yield role + "\n", message + self.sep
|
||||||
|
else:
|
||||||
|
yield role + "\n", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
|
||||||
|
yield "", system_prompt
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
yield role, message + self.sep
|
||||||
|
else:
|
||||||
|
yield role, ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.NO_COLON_TWO:
|
||||||
|
seps = [self.sep, self.sep2]
|
||||||
|
yield "", system_prompt
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message:
|
||||||
|
yield role, message + seps[i % 2]
|
||||||
|
else:
|
||||||
|
yield role, ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.RWKV:
|
||||||
|
yield "", system_prompt
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message:
|
||||||
|
yield role + ": ", message.replace("\r\n", "\n").replace(
|
||||||
|
"\n\n", "\n"
|
||||||
|
) + "\n\n"
|
||||||
|
else:
|
||||||
|
yield role + ":", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.LLAMA2:
|
||||||
|
seps = [self.sep, self.sep2]
|
||||||
|
if self.system_message:
|
||||||
|
yield "", system_prompt
|
||||||
|
else:
|
||||||
|
yield "", "[INST] "
|
||||||
|
for i, (role, message) in enumerate(self.messages[1:]):
|
||||||
|
if message:
|
||||||
|
yield role + " ", message + seps[i % 2]
|
||||||
|
else:
|
||||||
|
yield role, ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.CHATGLM:
|
||||||
|
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
||||||
|
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||||
|
round_add_n = 1 if self.name == "chatglm2" else 0
|
||||||
|
if system_prompt:
|
||||||
|
yield "", system_prompt + self.sep
|
||||||
|
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if i % 2 == 0:
|
||||||
|
yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
|
||||||
|
|
||||||
|
if message:
|
||||||
|
yield f"{role}:", f"{message}{self.sep}"
|
||||||
|
else:
|
||||||
|
yield f"{role}:", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.CHATML:
|
||||||
|
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
yield role + "\n", message + self.sep + "\n"
|
||||||
|
else:
|
||||||
|
yield role + "\n", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.CHATINTERN:
|
||||||
|
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
|
||||||
|
seps = [self.sep, self.sep2]
|
||||||
|
yield "", system_prompt
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
prefix = "<s>" if i % 2 == 0 else ""
|
||||||
|
if message:
|
||||||
|
yield prefix + role + ":", message + seps[i % 2] + "\n"
|
||||||
|
else:
|
||||||
|
yield role + ":", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.DOLLY:
|
||||||
|
seps = [self.sep, self.sep2]
|
||||||
|
yield "", system_prompt
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message:
|
||||||
|
suffix = "\n\n" if i % 2 == 1 else ""
|
||||||
|
yield role + ":\n", message + seps[i % 2] + suffix
|
||||||
|
else:
|
||||||
|
yield role + ":\n", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.PHOENIX:
|
||||||
|
yield "", system_prompt
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
yield role + ": ", "<s>" + message + "</s>"
|
||||||
|
else:
|
||||||
|
yield role + ": " + "<s>", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.ROBIN:
|
||||||
|
yield "", system_prompt + self.sep
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
yield role + ":\n", message + self.sep
|
||||||
|
else:
|
||||||
|
yield role + ":\n", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.FALCON_CHAT:
|
||||||
|
if self.system_message:
|
||||||
|
yield "", system_prompt + self.sep
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
yield role + ": ", message + self.sep
|
||||||
|
else:
|
||||||
|
yield role + ":", ""
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||||
|
|
||||||
|
|
||||||
|
def add_get_turns_to_conversation():
|
||||||
|
import fastchat.conversation
|
||||||
|
|
||||||
|
fastchat.conversation.Conversation.get_turns = get_turns
|
||||||
|
fastchat.conversation.Conversation.get_prompt = get_prompt
|
||||||
@@ -1,12 +1,35 @@
|
|||||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||||
from axolotl.prompters import PromptStyle, ShareGPTPrompter
|
from axolotl.prompters import ShareGPTPrompterV2
|
||||||
|
|
||||||
|
register_conv_template(
|
||||||
|
Conversation(
|
||||||
|
name="chatml",
|
||||||
|
system_template="<|im_start|>system\n{system_message}",
|
||||||
|
system_message="You are a helpful assistant.",
|
||||||
|
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
||||||
|
sep_style=SeparatorStyle.CHATML,
|
||||||
|
sep="<|im_end|>\n",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg):
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
|
conversation = (
|
||||||
|
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||||
|
)
|
||||||
|
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||||
|
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||||
return SimpleShareGPTPromptTokenizingStrategy(
|
return SimpleShareGPTPromptTokenizingStrategy(
|
||||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
ShareGPTPrompterV2(
|
||||||
|
conversation=conversation,
|
||||||
|
role_key_model=field_model,
|
||||||
|
role_key_human=field_human,
|
||||||
|
),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
@@ -15,7 +38,7 @@ def load(tokenizer, cfg):
|
|||||||
|
|
||||||
def load_role(tokenizer, cfg):
|
def load_role(tokenizer, cfg):
|
||||||
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
ShareGPTPrompterV2(),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
@@ -24,7 +47,7 @@ def load_role(tokenizer, cfg):
|
|||||||
|
|
||||||
def load_guanaco(tokenizer, cfg):
|
def load_guanaco(tokenizer, cfg):
|
||||||
return GuanacoShareGPTPromptTokenizingStrategy(
|
return GuanacoShareGPTPromptTokenizingStrategy(
|
||||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
ShareGPTPrompterV2(),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
"""Module for Jokes prompts using sharegpt style """
|
"""Module for Jokes prompts using sharegpt style """
|
||||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||||
from axolotl.prompters import PromptStyle, ShareGPTPrompter
|
from axolotl.prompters import ShareGPTPrompterV2
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg):
|
def load(tokenizer, cfg):
|
||||||
return SimpleJokesShareGPTPromptTokenizingStrategy(
|
return SimpleJokesShareGPTPromptTokenizingStrategy(
|
||||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
ShareGPTPrompterV2(),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
|
|||||||
@@ -6,8 +6,12 @@ import functools
|
|||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Dict, List, Tuple, Union
|
||||||
|
|
||||||
|
from fastchat.conversation import Conversation
|
||||||
from transformers import BatchEncoding, PreTrainedTokenizer
|
from transformers import BatchEncoding, PreTrainedTokenizer
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.fastchat_conversation_turns import (
|
||||||
|
add_get_turns_to_conversation,
|
||||||
|
)
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
@@ -18,6 +22,8 @@ LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
|
|||||||
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
|
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
|
||||||
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
|
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
|
||||||
|
|
||||||
|
add_get_turns_to_conversation()
|
||||||
|
|
||||||
|
|
||||||
class InvalidDataException(Exception):
|
class InvalidDataException(Exception):
|
||||||
"""
|
"""
|
||||||
@@ -352,18 +358,21 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
result, current_len = tokenize_prompt_default()
|
result, current_len = tokenize_prompt_default()
|
||||||
user_token = self._get_user_token()
|
user_token = self._get_user_token()
|
||||||
assistant_token = self._get_assistant_token()
|
assistant_token = self._get_assistant_token()
|
||||||
|
conversation: Conversation = (
|
||||||
|
self.prompter._conversation # pylint: disable=protected-access
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
for _, part in enumerate(
|
for _, part in enumerate(
|
||||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
||||||
):
|
):
|
||||||
if isinstance(part, tuple):
|
if isinstance(part, tuple):
|
||||||
if part[0] == "USER:":
|
if conversation.roles[0] in part[0]:
|
||||||
turn = part[0] + part[1] if not user_token else part[1]
|
turn = part[0] + part[1] if not user_token else part[1]
|
||||||
# this is still the user query, we should
|
# this is still the user query, we should
|
||||||
if not part[1].strip():
|
if not part[1].strip():
|
||||||
LOG.warning(f"user turn has empty text: {prompt}")
|
LOG.warning(f"user turn has empty text: {prompt}")
|
||||||
res = self._tokenize(
|
res = self._tokenize(
|
||||||
turn.strip(),
|
turn,
|
||||||
add_eos_token=False,
|
add_eos_token=False,
|
||||||
strip_bos_token=True,
|
strip_bos_token=True,
|
||||||
)
|
)
|
||||||
@@ -371,14 +380,14 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
res["input_ids"] = [user_token, *res["input_ids"]]
|
res["input_ids"] = [user_token, *res["input_ids"]]
|
||||||
# everything from this is masked out from the labels
|
# everything from this is masked out from the labels
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
elif part[0] == "ASSISTANT:":
|
elif conversation.roles[1] in part[0]:
|
||||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||||
turn = part[0] + part[1] if not assistant_token else part[1]
|
turn = part[0] + part[1] if not assistant_token else part[1]
|
||||||
# this should be the assistant response, should end with an eos token
|
# this should be the assistant response, should end with an eos token
|
||||||
if not part[1].strip():
|
if not part[1].strip():
|
||||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
LOG.warning(f"assistant turn has empty text: {prompt}")
|
||||||
res = self._tokenize(
|
res = self._tokenize(
|
||||||
turn.strip(),
|
turn,
|
||||||
add_eos_token=True,
|
add_eos_token=True,
|
||||||
strip_bos_token=True,
|
strip_bos_token=True,
|
||||||
)
|
)
|
||||||
@@ -389,16 +398,17 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
]
|
]
|
||||||
# not masked out from labels
|
# not masked out from labels
|
||||||
labels = copy.deepcopy(res["input_ids"])
|
labels = copy.deepcopy(res["input_ids"])
|
||||||
elif part[0] == "SYSTEM:":
|
elif part[0] == "":
|
||||||
part = part[1] # Ignore the system role from preamble
|
turn = part[1]
|
||||||
# this is only ever the first part, should include the bos token and the user query
|
# this is only ever the first part, should include the bos token and the user query
|
||||||
res = self._tokenize(
|
res = self._tokenize(
|
||||||
part.strip(), add_eos_token=False, strip_bos_token=False
|
turn, add_eos_token=False, strip_bos_token=False
|
||||||
)
|
)
|
||||||
# everything from this is masked out from the labels
|
# everything from this is masked out from the labels
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
else:
|
else:
|
||||||
LOG.warning(f"unhandled role: {part[0]}")
|
LOG.warning(f"unhandled role: {part[0]}")
|
||||||
|
continue
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
result, current_len = parse_tokenized_to_result(
|
result, current_len = parse_tokenized_to_result(
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""Module containing prompters"""
|
"""Module containing prompters"""
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum, auto
|
from enum import Enum
|
||||||
from typing import Generator, List, Optional, Tuple, Union
|
from typing import Generator, Optional, Union
|
||||||
|
|
||||||
|
from fastchat.conversation import Conversation, get_conv_template
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
IGNORE_TOKEN_ID = -100
|
IGNORE_TOKEN_ID = -100
|
||||||
@@ -214,53 +215,6 @@ class ReflectAlpacaPrompter:
|
|||||||
yield res
|
yield res
|
||||||
|
|
||||||
|
|
||||||
class SeparatorStyle(Enum):
|
|
||||||
"""Different separator style."""
|
|
||||||
|
|
||||||
SINGLE = auto()
|
|
||||||
TWO = auto()
|
|
||||||
DOLLY = auto()
|
|
||||||
|
|
||||||
|
|
||||||
# TODO clean this 💩 up
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class Conversation:
|
|
||||||
"""A class that keeps all conversation history."""
|
|
||||||
|
|
||||||
system: str
|
|
||||||
roles: List[str]
|
|
||||||
messages: List[List[str]]
|
|
||||||
offset: int
|
|
||||||
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
|
||||||
sep: str = "###"
|
|
||||||
sep2: Optional[str] = None
|
|
||||||
|
|
||||||
def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
|
|
||||||
# seps = [self.sep, self.sep2]
|
|
||||||
preamble = self.system + self.sep
|
|
||||||
yield ("SYSTEM:", preamble)
|
|
||||||
for _, (role, message) in enumerate(self.messages):
|
|
||||||
if message:
|
|
||||||
yield (role + ":", " " + message)
|
|
||||||
else:
|
|
||||||
LOG.warning(f"role with empty message: {role}")
|
|
||||||
yield (role + ":", "")
|
|
||||||
|
|
||||||
def copy(self):
|
|
||||||
return Conversation(
|
|
||||||
system=self.system,
|
|
||||||
roles=self.roles,
|
|
||||||
messages=[[x, y] for x, y in self.messages],
|
|
||||||
offset=self.offset,
|
|
||||||
sep_style=self.sep_style,
|
|
||||||
sep=self.sep,
|
|
||||||
sep2=self.sep2,
|
|
||||||
)
|
|
||||||
|
|
||||||
def append_message(self, role, message):
|
|
||||||
self.messages.append([role, message])
|
|
||||||
|
|
||||||
|
|
||||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||||
"Role did not alternate between turns (gpt and human). Please check your data."
|
"Role did not alternate between turns (gpt and human). Please check your data."
|
||||||
)
|
)
|
||||||
@@ -271,28 +225,27 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
A prompter that generates prompts for the ShareGPT
|
A prompter that generates prompts for the ShareGPT
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
|
role_key_human = "human"
|
||||||
if prompt_style != PromptStyle.CHAT.value:
|
role_key_model = "gpt"
|
||||||
raise ValueError(
|
|
||||||
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
|
def __init__(
|
||||||
)
|
self,
|
||||||
system: str = (
|
prompt_style=None, # pylint: disable=unused-argument
|
||||||
system_prompt
|
conversation: Optional[Union[str, Conversation]] = None,
|
||||||
if system_prompt
|
role_key_human: Optional[str] = None,
|
||||||
else (
|
role_key_model: Optional[str] = None,
|
||||||
"A chat between a curious user and an artificial intelligence assistant. "
|
):
|
||||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
if conversation:
|
||||||
)
|
if isinstance(conversation, Conversation):
|
||||||
)
|
self._conversation = conversation
|
||||||
self._conversation = Conversation(
|
else:
|
||||||
system=system,
|
self._conversation = get_conv_template(conversation)
|
||||||
roles=["USER", "ASSISTANT"],
|
else:
|
||||||
messages=[],
|
self._conversation = get_conv_template("vicuna_v1.1")
|
||||||
offset=0,
|
if role_key_human:
|
||||||
sep_style=SeparatorStyle.TWO,
|
self.role_key_human = role_key_human
|
||||||
sep=" ",
|
if role_key_model:
|
||||||
sep2=" ",
|
self.role_key_model = role_key_model
|
||||||
)
|
|
||||||
|
|
||||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||||
if len(source) < 2:
|
if len(source) < 2:
|
||||||
@@ -306,17 +259,14 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
|
|
||||||
# Add the conversation system prompt if provided, otherwise use the default one
|
# Add the conversation system prompt if provided, otherwise use the default one
|
||||||
if source[0]["from"] == "system":
|
if source[0]["from"] == "system":
|
||||||
conv.system = source[0]["value"]
|
conv.set_system_message(source[0]["value"])
|
||||||
source.pop(0)
|
source.pop(0)
|
||||||
|
|
||||||
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Apply prompt templates
|
# Apply prompt templates
|
||||||
if (
|
if source[0]["from"] not in roles:
|
||||||
source[0]["from"] not in roles
|
|
||||||
or roles[source[0]["from"]] != conv.roles[0]
|
|
||||||
):
|
|
||||||
# Skip the first one if it is not from human
|
# Skip the first one if it is not from human
|
||||||
source = source[1:]
|
source = source[1:]
|
||||||
except IndexError as err:
|
except IndexError as err:
|
||||||
@@ -326,8 +276,29 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
conv.messages = []
|
conv.messages = []
|
||||||
for j, sentence in enumerate(source):
|
for j, sentence in enumerate(source):
|
||||||
role = roles[sentence["from"]]
|
role = roles[sentence["from"]]
|
||||||
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
if role != conv.roles[j % 2]:
|
||||||
|
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||||
conv.append_message(role, sentence["value"])
|
conv.append_message(role, sentence["value"])
|
||||||
|
|
||||||
for part in conv.get_prompt():
|
for part in conv.get_turns():
|
||||||
|
if part[0] and not part[1]:
|
||||||
|
LOG.warning(f"role with empty message: {part[0]}")
|
||||||
yield part
|
yield part
|
||||||
|
|
||||||
|
|
||||||
|
class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||||
|
"""
|
||||||
|
A V2 prompter that generates prompts for the ShareGPT
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
conversation: Optional[Union[str, Conversation]] = None,
|
||||||
|
role_key_human: Optional[str] = None,
|
||||||
|
role_key_model: Optional[str] = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
conversation=conversation,
|
||||||
|
role_key_human=role_key_human,
|
||||||
|
role_key_model=role_key_model,
|
||||||
|
)
|
||||||
|
|||||||
@@ -278,6 +278,25 @@ def validate_config(cfg):
|
|||||||
"`model_type: MixFormerSequentialForCausalLM` required for sample_packing"
|
"`model_type: MixFormerSequentialForCausalLM` required for sample_packing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.datasets:
|
||||||
|
for idx, ds_cfg in enumerate(cfg.datasets):
|
||||||
|
if ds_cfg.type == "sharegpt:chat":
|
||||||
|
LOG.warning(
|
||||||
|
PendingDeprecationWarning(
|
||||||
|
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cfg.datasets[idx].type = "sharegpt"
|
||||||
|
if "sharegpt_simple" in ds_cfg.type:
|
||||||
|
LOG.warning(
|
||||||
|
PendingDeprecationWarning(
|
||||||
|
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
||||||
|
"sharegpt_simple", "sharegpt"
|
||||||
|
)
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ from axolotl.prompt_tokenizers import (
|
|||||||
GPTeacherPromptTokenizingStrategy,
|
GPTeacherPromptTokenizingStrategy,
|
||||||
JeopardyPromptTokenizingStrategy,
|
JeopardyPromptTokenizingStrategy,
|
||||||
OpenAssistantPromptTokenizingStrategy,
|
OpenAssistantPromptTokenizingStrategy,
|
||||||
ShareGPTPromptTokenizingStrategy,
|
|
||||||
SummarizeTLDRPromptTokenizingStrategy,
|
SummarizeTLDRPromptTokenizingStrategy,
|
||||||
)
|
)
|
||||||
from axolotl.prompters import (
|
from axolotl.prompters import (
|
||||||
@@ -35,7 +34,6 @@ from axolotl.prompters import (
|
|||||||
MultipleChoiceConcisePrompter,
|
MultipleChoiceConcisePrompter,
|
||||||
MultipleChoiceExplainPrompter,
|
MultipleChoiceExplainPrompter,
|
||||||
ReflectAlpacaPrompter,
|
ReflectAlpacaPrompter,
|
||||||
ShareGPTPrompter,
|
|
||||||
SummarizeTLDRPrompter,
|
SummarizeTLDRPrompter,
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -320,15 +318,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
elif d_base_type == "sharegpt":
|
|
||||||
ds_strategy = ShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
else:
|
else:
|
||||||
suffix = ""
|
suffix = ""
|
||||||
if ":load_" in d.type:
|
if ":load_" in d.type:
|
||||||
|
|||||||
@@ -33,5 +33,6 @@ def check_example_labels(example, tokenizer, text_only=False):
|
|||||||
|
|
||||||
LOG.info(" ".join(colored_tokens))
|
LOG.info(" ".join(colored_tokens))
|
||||||
LOG.info("\n\n\n")
|
LOG.info("\n\n\n")
|
||||||
|
print(" ".join(colored_tokens))
|
||||||
|
|
||||||
return " ".join(colored_tokens)
|
return " ".join(colored_tokens)
|
||||||
|
|||||||
2
tests/fixtures/conversation.tokenized.json
vendored
2
tests/fixtures/conversation.tokenized.json
vendored
File diff suppressed because one or more lines are too long
@@ -21,7 +21,7 @@ from axolotl.prompt_tokenizers import (
|
|||||||
AlpacaPromptTokenizingStrategy,
|
AlpacaPromptTokenizingStrategy,
|
||||||
ShareGPTPromptTokenizingStrategy,
|
ShareGPTPromptTokenizingStrategy,
|
||||||
)
|
)
|
||||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
|
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -60,7 +60,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
) as fin:
|
) as fin:
|
||||||
data = fin.read()
|
data = fin.read()
|
||||||
tokenized_conversation = json.loads(data)
|
tokenized_conversation = json.loads(data)
|
||||||
prompter = ShareGPTPrompter("chat")
|
prompter = ShareGPTPrompterV2()
|
||||||
strat = ShareGPTPromptTokenizingStrategy(
|
strat = ShareGPTPromptTokenizingStrategy(
|
||||||
prompter,
|
prompter,
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
@@ -79,7 +79,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
) as fin:
|
) as fin:
|
||||||
data = fin.read()
|
data = fin.read()
|
||||||
conversation = json.loads(data)
|
conversation = json.loads(data)
|
||||||
prompter = ShareGPTPrompter("chat")
|
prompter = ShareGPTPrompterV2()
|
||||||
strat = ShareGPTPromptTokenizingStrategy(
|
strat = ShareGPTPromptTokenizingStrategy(
|
||||||
prompter,
|
prompter,
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
|
|||||||
@@ -374,3 +374,26 @@ class ValidationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_sharegpt_deprecation(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
|
||||||
|
)
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(cfg)
|
||||||
|
assert any(
|
||||||
|
"`type: sharegpt:chat` will soon be deprecated." in record.message
|
||||||
|
for record in self._caplog.records
|
||||||
|
)
|
||||||
|
assert cfg.datasets[0].type == "sharegpt"
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{"datasets": [{"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}]}
|
||||||
|
)
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(cfg)
|
||||||
|
assert any(
|
||||||
|
"`type: sharegpt_simple` will soon be deprecated." in record.message
|
||||||
|
for record in self._caplog.records
|
||||||
|
)
|
||||||
|
assert cfg.datasets[0].type == "sharegpt:load_role"
|
||||||
|
|||||||
Reference in New Issue
Block a user