tokenization fixes

This commit is contained in:
Wing Lian
2023-05-21 08:33:06 -04:00
parent 1d5ab84486
commit 4ea9a66dbd
4 changed files with 20 additions and 7 deletions

View File

@@ -0,0 +1,8 @@
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
def load(tokenizer, cfg):
    """Build the Alpaca "chat"-style prompt tokenizing strategy.

    Wires an AlpacaPrompter (chat prompt style) into an
    AlpacaPromptTokenizingStrategy, pulling the train_on_inputs flag and
    sequence length from the run config.
    """
    prompter = AlpacaPrompter(PromptStyle.chat)
    return AlpacaPromptTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )

View File

@@ -38,14 +38,14 @@ class PromptTokenizingStrategy(abc.ABC):
@functools.cache
def _get_user_token(self):
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
if type(id_or_ids, (int,)):
if isinstance(id_or_ids, (int,)):
return id_or_ids
return False
@functools.cache
def _get_assistant_token(self):
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
if type(id_or_ids, (int,)):
if isinstance(id_or_ids, (int,)):
return id_or_ids
return False
@@ -272,15 +272,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
# this is still the user query; presumably it should be masked out of the labels below — confirm against full method
res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=True)
if user_token:
res = [user_token, *res]
res["input_ids"] = [user_token, *res["input_ids"]]
# everything from this is masked out from the labels
labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
elif part[0] == "ASSISTANT:":
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
part = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistant response, should end with an eos token
res = self._tokenize(part.strip(), add_eos_token=True, strip_bos_token=True)
if assistant_token:
res = [assistant_token, *res]
res["input_ids"] = [assistant_token, *res["input_ids"]]
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
else:

View File

@@ -12,6 +12,7 @@ from datasets import (
from huggingface_hub import hf_hub_download
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
from axolotl.prompt_strategies import load
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
GPTeacherPromptTokenizingStrategy,
@@ -94,10 +95,13 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
if not ds:
raise Exception("unhandled dataset load")
d_type = d.type
d_type_split = d.type.split(":")
d_type_split = d_type.split(":")
d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if d_base_type == "alpaca":
if (ds_strategy := load(d.type, tokenizer, cfg)):
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper)
elif d_base_type == "alpaca":
ds_strategy = AlpacaPromptTokenizingStrategy(
AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
)

View File

@@ -220,7 +220,7 @@ def load_model(
for k, v in cfg.special_tokens.items():
tokenizer.add_special_tokens({k: v})
if cfg.tokens:
tokenizer.add_tokens(cfg.tokens)
tokenizer.add_tokens(list(cfg.tokens))
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
model.resize_token_embeddings(embeddings_len)