# ---------------------------------------------------------------------------
# New file: src/axolotl/prompt_strategies/alpaca_instruct.py
# (new file mode 100644, index 000000000..8f09407ad)
# ---------------------------------------------------------------------------
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle


def load(tokenizer, cfg):
    """Build the Alpaca tokenizing strategy using the ``instruct`` prompt style.

    Args:
        tokenizer: tokenizer handed through to the strategy.
        cfg: config object; ``train_on_inputs`` and ``sequence_len`` are read.

    Returns:
        An ``AlpacaPromptTokenizingStrategy`` instance.
    """
    return AlpacaPromptTokenizingStrategy(
        AlpacaPrompter(PromptStyle.instruct),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )


# ---------------------------------------------------------------------------
# New file: src/axolotl/prompt_strategies/pygmalion.py
# (new file mode 100644, index 000000000..3b6cbf0e3)
# ---------------------------------------------------------------------------
import copy
import logging
from collections import defaultdict
from typing import Generator, Tuple

from axolotl.prompt_tokenizers import PromptTokenizingStrategy

# Label value used to mask a position out of the training loss (this is the
# default ``ignore_index`` of PyTorch's cross-entropy loss).
IGNORE_TOKEN_ID = -100

# Marker stripped from the end of a system/persona message. Its length (8) is
# the amount the original code sliced off.
# NOTE(review): presumably the Pygmalion "<START>" separator -- confirm
# against the dataset.
_PERSONA_END_MARKER = "\n<START>"


class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
    """Tokenize Pygmalion-style conversations turn by turn.

    Each turn is rendered with a role tag (``<|system|>``, ``<|user|>`` or
    ``<|model|>``), tokenized on its own, and concatenated. Labels for system
    and human turns, and for the ``<|model|>`` tag of bot turns, are set to
    ``IGNORE_TOKEN_ID``; the rest of each bot turn keeps its token ids as
    labels.
    """

    # Token ids of the bare "<|model|>" tag; replaced per instance in __init__.
    bot_prefix_token_ids = []

    def __init__(self, prompter, tokenizer, *args, **kwargs):
        """Initialise the strategy and cache the ``<|model|>`` tag token ids.

        Args:
            prompter: object whose ``build_prompt`` yields ``(role, message)``.
            tokenizer: callable tokenizer exposing ``eos_token_id``,
                ``bos_token_id`` and ``pad_token_id``.
            *args, **kwargs: forwarded to the base class. ``load()`` passes
                ``train_on_inputs`` and ``sequence_len`` here; they were
                previously dropped, so the base class never received them.
        """
        super().__init__(prompter, tokenizer, *args, **kwargs)
        res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
        self.bot_prefix_token_ids = res["input_ids"]

    def tokenize_prompt(self, prompt):
        """Tokenize one conversation into aligned ids, mask and labels.

        Args:
            prompt: mapping with a ``"conversations"`` entry that is passed to
                ``self.prompter.build_prompt``.

        Returns:
            dict with ``input_ids``, ``attention_mask`` and ``labels`` lists,
            all of the same length.
        """
        result = {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
        }
        pad_token_id = self.tokenizer.pad_token_id

        for role, message in self.prompter.build_prompt(prompt["conversations"]):
            if role == "system":
                # Keeps the BOS token, adds no EOS, and drops the trailing
                # marker. Slicing by the marker's own length keeps the test
                # and the slice in agreement.
                if message.endswith(_PERSONA_END_MARKER):
                    message = message[: -len(_PERSONA_END_MARKER)]
                res = self._tokenize(
                    "<|system|>" + "Persona: " + message.strip(),
                    add_eos_token=False,
                    strip_bos_token=False,
                )
                # The whole system turn is masked out of the labels.
                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
            elif role == "human":
                res = self._tokenize(
                    "<|user|>" + " " + message.strip(),
                    add_eos_token=False,
                    strip_bos_token=True,
                )
                # The whole human turn is masked out of the labels.
                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
            elif role == "bot":
                res = self._tokenize(
                    "<|model|>" + " " + message.strip(),
                    add_eos_token=True,
                    strip_bos_token=True,
                )
                # The tag is already part of the tokenized text, so it is not
                # prepended a second time. Only the tag positions are masked;
                # the reply (and EOS) keep their ids as labels.
                # NOTE(review): assumes the tag tokenizes to the same ids at
                # the start of a longer string as it does alone -- confirm.
                masked = min(len(self.bot_prefix_token_ids), len(res["input_ids"]))
                labels = [IGNORE_TOKEN_ID] * masked + list(res["input_ids"][masked:])
            else:
                # Skip the turn entirely: there is nothing to tokenize and no
                # labels to attach to it.
                logging.warning("unknown role in conversation: %s", role)
                continue

            input_ids = res["input_ids"]
            result["input_ids"].extend(input_ids)
            result["attention_mask"].extend(
                1 if token_id != pad_token_id else 0 for token_id in input_ids
            )
            result["labels"].extend(labels)
        return result

    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
        """Tokenize ``prompt`` with optional EOS append and BOS strip.

        Args:
            prompt: text to tokenize.
            add_eos_token: append ``eos_token_id`` when it is not already the
                last token and the result is shorter than ``sequence_len``.
            strip_bos_token: drop a leading ``bos_token_id`` if present.

        Returns:
            The tokenizer output with ``labels`` set to a copy of
            ``input_ids``.
        """
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        input_ids = result["input_ids"]

        # An empty tokenization has no last token to inspect; treat it as
        # "EOS missing" instead of raising IndexError.
        if (
            add_eos_token
            and len(input_ids) < self.sequence_len
            and (not input_ids or input_ids[-1] != self.tokenizer.eos_token_id)
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        if (
            strip_bos_token
            and result["input_ids"]
            and result["input_ids"][0] == self.tokenizer.bos_token_id
        ):
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
        return result


class PygmalionPrompter:
    """Minimal prompter that exposes conversation turns as (role, text) pairs."""

    def __init__(self, *args, **kwargs):
        pass

    def build_prompt(
        self, source, *args, **kwargs
    ) -> Generator[Tuple[str, str], None, None]:
        """Yield ``(msg["role"], msg["value"])`` for every message in ``source``."""
        for msg in source:
            yield msg["role"], msg["value"]


def load(tokenizer, cfg):
    """Build the Pygmalion tokenizing strategy from the run configuration.

    Args:
        tokenizer: tokenizer handed through to the strategy.
        cfg: config object; ``train_on_inputs`` and ``sequence_len`` are read.

    Returns:
        A ``PygmalionPromptTokenizingStrategy`` instance.
    """
    return PygmalionPromptTokenizingStrategy(
        PygmalionPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )


# ---------------------------------------------------------------------------
# Modified file: src/axolotl/utils/data.py -- patch hunks kept verbatim below.
# They show only fragments of functions and continue on the following line.
# ---------------------------------------------------------------------------
# diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index d436face3..a6e886138 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -10,6 +10,7 @@ from datasets import ( concatenate_datasets, ) from huggingface_hub import hf_hub_download +from transformers import PreTrainedTokenizerBase from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset from axolotl.prompt_strategies import load @@ -37,12 +38,14 @@ from axolotl.prompters import ( def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path): + tokenizer_name = tokenizer.__class__.__name__ ds_hash = str( md5( ( str(cfg.sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets])) + + "|" + tokenizer_name ).encode("utf-8") ).hexdigest() ) @@ -192,7 +195,7 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa return dataset -def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path): +def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path): max_packed_sequence_len = ( cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len ) @@ -200,6 +203,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path): max_packed_sequence_len, cfg.sequence_len ) # make sure we don't accidentally set it larger than sequence_len + tokenizer_name = tokenizer.__class__.__name__ if cfg.max_packed_sequence_len is not None: # see if we can go ahead and load the stacked dataset seed = f"@{str(cfg.seed)}" if cfg.seed else "" @@ -211,6 +215,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path): + str(max_packed_sequence_len) + seed + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets])) + + "|" + tokenizer_name ).encode("utf-8") ).hexdigest() ) @@ -238,6 +243,11 @@ def
load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path): ) dataset = load_from_disk(str(prepared_ds_path)) logging.info("Prepared packed dataset loaded from disk...") + if cfg.push_dataset_to_hub: + logging.info( + f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" + ) + dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True) else: dataset = load_tokenized_prepared_datasets( tokenizer, cfg, default_dataset_prepared_path