diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py
new file mode 100644
index 000000000..1cd99bd9f
--- /dev/null
+++ b/src/axolotl/prompt_strategies/alpaca_chat.py
@@ -0,0 +1,8 @@
+from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter, PromptStyle
+
+
+def load(tokenizer, cfg):
+    return AlpacaPromptTokenizingStrategy(
+        AlpacaPrompter(PromptStyle.chat), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index c33551135..6c20e7729 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -38,14 +38,14 @@ class PromptTokenizingStrategy(abc.ABC):
     @functools.cache
     def _get_user_token(self):
         id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
-        if type(id_or_ids, (int,)):
+        if isinstance(id_or_ids, (int,)):
             return id_or_ids
         return False
 
     @functools.cache
     def _get_assistant_token(self):
         id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
-        if type(id_or_ids, (int,)):
+        if isinstance(id_or_ids, (int,)):
             return id_or_ids
         return False
 
@@ -272,15 +272,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                         # this is still the user query, we should
                         res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=True)
                         if user_token:
-                            res = [user_token, *res]
+                            res["input_ids"] = [user_token, *res["input_ids"]]
                         # everything from this is masked out from the labels
                         labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
                     elif part[0] == "ASSISTANT:":
+                        # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
                         part = part[0] + part[1] if not assistant_token else part[1]
                         # this should be the assistent response, should end with an eos token
                         res = self._tokenize(part.strip(), add_eos_token=True, strip_bos_token=True)
                         if assistant_token:
-                            res = [assistant_token, *res]
+                            res["input_ids"] = [assistant_token, *res["input_ids"]]
                         # not masked out from labels
                         labels = copy.deepcopy(res["input_ids"])
                     else:
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 22bc23359..d436face3 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -12,6 +12,7 @@ from datasets import (
 from huggingface_hub import hf_hub_download
 
 from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
+from axolotl.prompt_strategies import load
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     GPTeacherPromptTokenizingStrategy,
@@ -94,10 +95,13 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
             if not ds:
                 raise Exception("unhandled dataset load")
             d_type = d.type
-            d_type_split = d.type.split(":")
+            d_type_split = d_type.split(":")
             d_base_type = d_type_split[0]
             d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
-            if d_base_type == "alpaca":
+            if (ds_strategy := load(d.type, tokenizer, cfg)):
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                datasets.append(ds_wrapper)
+            elif d_base_type == "alpaca":
                 ds_strategy = AlpacaPromptTokenizingStrategy(
                     AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
                 )
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 0217f062b..a2fe0b494 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -220,7 +220,7 @@ def load_model(
         for k, v in cfg.special_tokens.items():
             tokenizer.add_special_tokens({k: v})
     if cfg.tokens:
-        tokenizer.add_tokens(cfg.tokens)
+        tokenizer.add_tokens(list(cfg.tokens))
 
     embeddings_len = math.ceil(len(tokenizer) / 32) * 32
     model.resize_token_embeddings(embeddings_len)