use fastchat conversations template (#578)

* use fastchat conversations template

* require fastchat (fschat) pip install

* handle roles dynamically from conversation

* tweak fastchat conversation with a monkeypatch to get individual turns

* fix up so it works with multiple conversation styles, and don't strip the turns

* fix sharegpt fixture now that we're using a more correct tokenization

* use a new prompter and support fastchat conversation type

* use sharegpt from prompt strategies now

* update docs, add chatml template

* add a newline after im_end token

* ensure we correctly set system message

* update per PR feedback to handle deprecated sharegpt types

* don't add duplicate wandb req

* make sharegpt fields configurable from yml

* llama2 fixes

* don't fail fatally when turns are improper
This commit is contained in:
Wing Lian
2023-09-27 12:10:45 -04:00
committed by GitHub
parent 60c7c48c97
commit e7d3e2dbb6
13 changed files with 324 additions and 112 deletions

View File

@@ -180,7 +180,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json ```json
{"instruction": "...", "input": "...", "output": "..."} {"instruction": "...", "input": "...", "output": "..."}
``` ```
- `sharegpt:chat`: conversations where `from` is `human`/`gpt` - `sharegpt`: conversations where `from` is `human`/`gpt`
```json ```json
{"conversations": [{"from": "...", "value": "..."}]} {"conversations": [{"from": "...", "value": "..."}]}
``` ```
@@ -269,11 +269,11 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json ```json
{"prompt": "...", "generation": "..."} {"prompt": "...", "generation": "..."}
``` ```
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from` - `sharegpt.load_role`: conversations where `role` is used instead of `from`
```json ```json
{"conversations": [{"role": "...", "value": "..."}]} {"conversations": [{"role": "...", "value": "..."}]}
``` ```
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt - `sharegpt.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
```json ```json
{"conversations": [{"from": "...", "value": "..."}]} {"conversations": [{"from": "...", "value": "..."}]}
``` ```
@@ -443,6 +443,7 @@ datasets:
data_files: # Optional[str] path to source data files data_files: # Optional[str] path to source data files
shards: # Optional[int] number of shards to split data into shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load name: # Optional[str] name of dataset configuration to load
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
# custom user prompt # custom user prompt
- path: repo - path: repo

View File

@@ -31,3 +31,4 @@ scipy
scikit-learn==1.2.2 scikit-learn==1.2.2
pynvml pynvml
art art
fschat==0.2.29

View File

@@ -0,0 +1,174 @@
"""
monkeypatch to add a get_turns method
"""
import logging
from typing import Generator, Tuple
from fastchat.conversation import SeparatorStyle
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
def get_prompt(self) -> str:
ret = ""
for role, msg in self.get_turns():
ret += role + msg
return ret
def get_turns( # pylint: disable=too-many-return-statements
self,
) -> Generator[Tuple[str, str], None, None]:
"""Get the prompt for generation."""
system_prompt = self.system_template.format(system_message=self.system_message)
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
seps = [self.sep, self.sep2]
yield "", system_prompt + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
yield role + ": ", message + seps[i % 2]
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ": ", "" # must be end with a space
return
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
yield "", "" if system_prompt == "" else system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + "\n", message + self.sep
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role, message + self.sep
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.NO_COLON_TWO:
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield role, message + seps[i % 2]
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.RWKV:
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield role + ": ", message.replace("\r\n", "\n").replace(
"\n\n", "\n"
) + "\n\n"
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
if self.system_message:
yield "", system_prompt
else:
yield "", "[INST] "
for i, (role, message) in enumerate(self.messages[1:]):
if message:
yield role + " ", message + seps[i % 2]
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.CHATGLM:
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
round_add_n = 1 if self.name == "chatglm2" else 0
if system_prompt:
yield "", system_prompt + self.sep
for i, (role, message) in enumerate(self.messages):
if i % 2 == 0:
yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
if message:
yield f"{role}", f"{message}{self.sep}"
else:
yield f"{role}", ""
return
if self.sep_style == SeparatorStyle.CHATML:
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
for role, message in self.messages:
if message:
yield role + "\n", message + self.sep + "\n"
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.CHATINTERN:
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
prefix = "<s>" if i % 2 == 0 else ""
if message:
yield prefix + role + ":", message + seps[i % 2] + "\n"
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.DOLLY:
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
suffix = "\n\n" if i % 2 == 1 else ""
yield role + ":\n", message + seps[i % 2] + suffix
else:
yield role + ":\n", ""
return
if self.sep_style == SeparatorStyle.PHOENIX:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role + ": ", "<s>" + message + "</s>"
else:
yield role + ": " + "<s>", ""
return
if self.sep_style == SeparatorStyle.ROBIN:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ":\n", message + self.sep
else:
yield role + ":\n", ""
return
if self.sep_style == SeparatorStyle.FALCON_CHAT:
if self.system_message:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ":", ""
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def add_get_turns_to_conversation():
import fastchat.conversation
fastchat.conversation.Conversation.get_turns = get_turns
fastchat.conversation.Conversation.get_prompt = get_prompt

View File

@@ -1,12 +1,35 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class""" """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
from typing import Any, Dict, Optional
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import PromptStyle, ShareGPTPrompter from axolotl.prompters import ShareGPTPrompterV2
register_conv_template(
Conversation(
name="chatml",
system_template="<|im_start|>system\n{system_message}",
system_message="You are a helpful assistant.",
roles=["<|im_start|>user", "<|im_start|>assistant"],
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>\n",
)
)
def load(tokenizer, cfg): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
return SimpleShareGPTPromptTokenizingStrategy( return SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value), ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
@@ -15,7 +38,7 @@ def load(tokenizer, cfg):
def load_role(tokenizer, cfg): def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy( return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value), ShareGPTPrompterV2(),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
@@ -24,7 +47,7 @@ def load_role(tokenizer, cfg):
def load_guanaco(tokenizer, cfg): def load_guanaco(tokenizer, cfg):
return GuanacoShareGPTPromptTokenizingStrategy( return GuanacoShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value), ShareGPTPrompterV2(),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,

View File

@@ -1,11 +1,11 @@
"""Module for Jokes prompts using sharegpt style """ """Module for Jokes prompts using sharegpt style """
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import PromptStyle, ShareGPTPrompter from axolotl.prompters import ShareGPTPrompterV2
def load(tokenizer, cfg): def load(tokenizer, cfg):
return SimpleJokesShareGPTPromptTokenizingStrategy( return SimpleJokesShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value), ShareGPTPrompterV2(),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,

View File

@@ -6,8 +6,12 @@ import functools
import logging import logging
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
from fastchat.conversation import Conversation
from transformers import BatchEncoding, PreTrainedTokenizer from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.monkeypatch.fastchat_conversation_turns import (
add_get_turns_to_conversation,
)
from axolotl.prompters import IGNORE_TOKEN_ID from axolotl.prompters import IGNORE_TOKEN_ID
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -18,6 +22,8 @@ LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
add_get_turns_to_conversation()
class InvalidDataException(Exception): class InvalidDataException(Exception):
""" """
@@ -352,18 +358,21 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
result, current_len = tokenize_prompt_default() result, current_len = tokenize_prompt_default()
user_token = self._get_user_token() user_token = self._get_user_token()
assistant_token = self._get_assistant_token() assistant_token = self._get_assistant_token()
conversation: Conversation = (
self.prompter._conversation # pylint: disable=protected-access
)
try: try:
for _, part in enumerate( for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt)) self.prompter.build_prompt(self.get_conversation_thread(prompt))
): ):
if isinstance(part, tuple): if isinstance(part, tuple):
if part[0] == "USER:": if conversation.roles[0] in part[0]:
turn = part[0] + part[1] if not user_token else part[1] turn = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should # this is still the user query, we should
if not part[1].strip(): if not part[1].strip():
LOG.warning(f"user turn has empty text: {prompt}") LOG.warning(f"user turn has empty text: {prompt}")
res = self._tokenize( res = self._tokenize(
turn.strip(), turn,
add_eos_token=False, add_eos_token=False,
strip_bos_token=True, strip_bos_token=True,
) )
@@ -371,14 +380,14 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
res["input_ids"] = [user_token, *res["input_ids"]] res["input_ids"] = [user_token, *res["input_ids"]]
# everything from this is masked out from the labels # everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif part[0] == "ASSISTANT:": elif conversation.roles[1] in part[0]:
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
turn = part[0] + part[1] if not assistant_token else part[1] turn = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistant response, should end with an eos token # this should be the assistant response, should end with an eos token
if not part[1].strip(): if not part[1].strip():
LOG.warning(f"assistant turn has empty text: {prompt}") LOG.warning(f"assistant turn has empty text: {prompt}")
res = self._tokenize( res = self._tokenize(
turn.strip(), turn,
add_eos_token=True, add_eos_token=True,
strip_bos_token=True, strip_bos_token=True,
) )
@@ -389,16 +398,17 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
] ]
# not masked out from labels # not masked out from labels
labels = copy.deepcopy(res["input_ids"]) labels = copy.deepcopy(res["input_ids"])
elif part[0] == "SYSTEM:": elif part[0] == "":
part = part[1] # Ignore the system role from preamble turn = part[1]
# this is only ever the first part, should include the bos token and the user query # this is only ever the first part, should include the bos token and the user query
res = self._tokenize( res = self._tokenize(
part.strip(), add_eos_token=False, strip_bos_token=False turn, add_eos_token=False, strip_bos_token=False
) )
# everything from this is masked out from the labels # everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else: else:
LOG.warning(f"unhandled role: {part[0]}") LOG.warning(f"unhandled role: {part[0]}")
continue
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result( result, current_len = parse_tokenized_to_result(

View File

@@ -1,9 +1,10 @@
"""Module containing prompters""" """Module containing prompters"""
import dataclasses
import logging import logging
from enum import Enum, auto from enum import Enum
from typing import Generator, List, Optional, Tuple, Union from typing import Generator, Optional, Union
from fastchat.conversation import Conversation, get_conv_template
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100 IGNORE_TOKEN_ID = -100
@@ -214,53 +215,6 @@ class ReflectAlpacaPrompter:
yield res yield res
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
DOLLY = auto()
# TODO clean this 💩 up
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: Optional[str] = None
def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
# seps = [self.sep, self.sep2]
preamble = self.system + self.sep
yield ("SYSTEM:", preamble)
for _, (role, message) in enumerate(self.messages):
if message:
yield (role + ":", " " + message)
else:
LOG.warning(f"role with empty message: {role}")
yield (role + ":", "")
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
)
def append_message(self, role, message):
self.messages.append([role, message])
SHAREGPT_ASSERTION_FAILED_ROLE = ( SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data." "Role did not alternate between turns (gpt and human). Please check your data."
) )
@@ -271,28 +225,27 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
A prompter that generates prompts for the ShareGPT A prompter that generates prompts for the ShareGPT
""" """
def __init__(self, prompt_style=None, system_prompt: Optional[str] = None): role_key_human = "human"
if prompt_style != PromptStyle.CHAT.value: role_key_model = "gpt"
raise ValueError(
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})" def __init__(
) self,
system: str = ( prompt_style=None, # pylint: disable=unused-argument
system_prompt conversation: Optional[Union[str, Conversation]] = None,
if system_prompt role_key_human: Optional[str] = None,
else ( role_key_model: Optional[str] = None,
"A chat between a curious user and an artificial intelligence assistant. " ):
"The assistant gives helpful, detailed, and polite answers to the user's questions." if conversation:
) if isinstance(conversation, Conversation):
) self._conversation = conversation
self._conversation = Conversation( else:
system=system, self._conversation = get_conv_template(conversation)
roles=["USER", "ASSISTANT"], else:
messages=[], self._conversation = get_conv_template("vicuna_v1.1")
offset=0, if role_key_human:
sep_style=SeparatorStyle.TWO, self.role_key_human = role_key_human
sep=" ", if role_key_model:
sep2=" ", self.role_key_model = role_key_model
)
def build_prompt(self, source) -> Generator[str, None, None]: def build_prompt(self, source) -> Generator[str, None, None]:
if len(source) < 2: if len(source) < 2:
@@ -306,17 +259,14 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
# Add the conversation system prompt if provided, otherwise use the default one # Add the conversation system prompt if provided, otherwise use the default one
if source[0]["from"] == "system": if source[0]["from"] == "system":
conv.system = source[0]["value"] conv.set_system_message(source[0]["value"])
source.pop(0) source.pop(0)
roles = {"human": conv.roles[0], "gpt": conv.roles[1]} roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
try: try:
# Apply prompt templates # Apply prompt templates
if ( if source[0]["from"] not in roles:
source[0]["from"] not in roles
or roles[source[0]["from"]] != conv.roles[0]
):
# Skip the first one if it is not from human # Skip the first one if it is not from human
source = source[1:] source = source[1:]
except IndexError as err: except IndexError as err:
@@ -326,8 +276,29 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
conv.messages = [] conv.messages = []
for j, sentence in enumerate(source): for j, sentence in enumerate(source):
role = roles[sentence["from"]] role = roles[sentence["from"]]
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE if role != conv.roles[j % 2]:
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])
for part in conv.get_prompt(): for part in conv.get_turns():
if part[0] and not part[1]:
LOG.warning(f"role with empty message: {part[0]}")
yield part yield part
class ShareGPTPrompterV2(ShareGPTPrompter):
"""
A V2 prompter that generates prompts for the ShareGPT
"""
def __init__(
self,
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
):
super().__init__(
conversation=conversation,
role_key_human=role_key_human,
role_key_model=role_key_model,
)

View File

@@ -278,6 +278,25 @@ def validate_config(cfg):
"`model_type: MixFormerSequentialForCausalLM` required for sample_packing" "`model_type: MixFormerSequentialForCausalLM` required for sample_packing"
) )
if cfg.datasets:
for idx, ds_cfg in enumerate(cfg.datasets):
if ds_cfg.type == "sharegpt:chat":
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
cfg.datasets[idx].type = "sharegpt"
if "sharegpt_simple" in ds_cfg.type:
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
"sharegpt_simple", "sharegpt"
)
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -25,7 +25,6 @@ from axolotl.prompt_tokenizers import (
GPTeacherPromptTokenizingStrategy, GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy, OpenAssistantPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
SummarizeTLDRPromptTokenizingStrategy, SummarizeTLDRPromptTokenizingStrategy,
) )
from axolotl.prompters import ( from axolotl.prompters import (
@@ -35,7 +34,6 @@ from axolotl.prompters import (
MultipleChoiceConcisePrompter, MultipleChoiceConcisePrompter,
MultipleChoiceExplainPrompter, MultipleChoiceExplainPrompter,
ReflectAlpacaPrompter, ReflectAlpacaPrompter,
ShareGPTPrompter,
SummarizeTLDRPrompter, SummarizeTLDRPrompter,
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -320,15 +318,6 @@ def load_tokenized_prepared_datasets(
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
elif d_base_type == "sharegpt":
ds_strategy = ShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
else: else:
suffix = "" suffix = ""
if ":load_" in d.type: if ":load_" in d.type:

View File

@@ -33,5 +33,6 @@ def check_example_labels(example, tokenizer, text_only=False):
LOG.info(" ".join(colored_tokens)) LOG.info(" ".join(colored_tokens))
LOG.info("\n\n\n") LOG.info("\n\n\n")
print(" ".join(colored_tokens))
return " ".join(colored_tokens) return " ".join(colored_tokens)

File diff suppressed because one or more lines are too long

View File

@@ -21,7 +21,7 @@ from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy, AlpacaPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy,
) )
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -60,7 +60,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
) as fin: ) as fin:
data = fin.read() data = fin.read()
tokenized_conversation = json.loads(data) tokenized_conversation = json.loads(data)
prompter = ShareGPTPrompter("chat") prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy( strat = ShareGPTPromptTokenizingStrategy(
prompter, prompter,
self.tokenizer, self.tokenizer,
@@ -79,7 +79,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
) as fin: ) as fin:
data = fin.read() data = fin.read()
conversation = json.loads(data) conversation = json.loads(data)
prompter = ShareGPTPrompter("chat") prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy( strat = ShareGPTPromptTokenizingStrategy(
prompter, prompter,
self.tokenizer, self.tokenizer,

View File

@@ -374,3 +374,26 @@ class ValidationTest(unittest.TestCase):
) )
validate_config(cfg) validate_config(cfg)
def test_sharegpt_deprecation(self):
cfg = DictDefault(
{"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"`type: sharegpt:chat` will soon be deprecated." in record.message
for record in self._caplog.records
)
assert cfg.datasets[0].type == "sharegpt"
cfg = DictDefault(
{"datasets": [{"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}]}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"`type: sharegpt_simple` will soon be deprecated." in record.message
for record in self._caplog.records
)
assert cfg.datasets[0].type == "sharegpt:load_role"