diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index d9acf5715..0e166f6f0 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -106,7 +106,7 @@ class ConstantLengthDataset(IterableDataset):
                     }
                 else:
                     logging.warning(
-                        "dropping batch due to tensor size mismatch"
+                        f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
                     )
                 buffer = {"input_ids": [], "attention_mask": [], "labels": []}
                 buffer_len = 0
diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py
index 0bb936e4f..dcdc4315f 100644
--- a/src/axolotl/prompt_strategies/__init__.py
+++ b/src/axolotl/prompt_strategies/__init__.py
@@ -1,11 +1,13 @@
 import importlib
-from functools import cache
 
-@cache
 def load(strategy, tokenizer, cfg):
     try:
-        m = importlib.import_module(f".{strategy}", axolotl.prompt_strategies)
-        fn = getattr(m, "load")
+        load_fn = "load"
+        if strategy.split(".")[-1].startswith("load_"):
+            load_fn = strategy.split(".")[-1]
+            strategy = ".".join(strategy.split(".")[:-1])
+        m = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
+        fn = getattr(m, load_fn)
         return fn(tokenizer, cfg)
     except:
         pass
diff --git a/src/axolotl/prompt_strategies/creative_acr.py b/src/axolotl/prompt_strategies/creative_acr.py
new file mode 100644
index 000000000..e68e206de
--- /dev/null
+++ b/src/axolotl/prompt_strategies/creative_acr.py
@@ -0,0 +1,137 @@
+from typing import Tuple, Union, Generator
+
+import yaml
+from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
+
+
+class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
+        question = prompt["instruction"]
+        answer = prompt["revision"]  # don't use prompt["answer"], that's data we don't want in the dataset
+        return (
+            question,
+            "",
+            answer,
+        )
+
+
+class CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    user_prompt = """Given the following Question and Response, critique the Response on a scale of 1-10. You should critique the answer in the following criteria:
+refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." are a form of refusal and should have a low refusal score. 10 is no refusal, 0 means the Response refuses to answer the question.
+prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 means no prescriptive bias is present, 0 means there is prescriptive bias.
+creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity.
+comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth.
+Then, for each criterion, give a detailed explanation of why the Response doesn't meet the criterion (don't reference the score), what it did right, and how it could be improved for that criterion.
+
+Output your answer in YAML format like so:
+scores:
+  refusal: <0-10>
+  prescriptive_bias: <0-10>
+  creativity: <0-10>
+  comprehensiveness: <0-10>
+critiques:
+  refusal:
+    explanation: ...
+    improvements: ...
+  prescriptive_bias:
+    explanation: ...
+    improvements: ...
+  creativity:
+    explanation: ...
+    improvements: ...
+  comprehensiveness:
+    explanation: ...
+    improvements: ...
+
+Question: {question}
+Answer: {answer}
+"""
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
+        scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper)
+        critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper)
+        evaluation = scores + critiques
+        question = prompt["instruction"]
+        answer = prompt["answer"]
+        return (
+            self.user_prompt.format(question=question, answer=answer),
+            "",
+            evaluation,
+        )
+
+
+class CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    user_prompt = """Definitions:
+refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." are a form of refusal and should have a low refusal score. 10 is no refusal, 0 means the Response refuses to answer the question.
+prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 means no prescriptive bias is present, 0 means there is prescriptive bias.
+creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity.
+comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth.
+
+Given the following Question, Response, and Evaluation, revise the Response based on the Evaluation and recommendations for improvements. Reply only with the revised response.
+
+Question: {question}
+Answer: {answer}
+Evaluation:
+{evaluation}
+"""
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
+        scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper)
+        critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper)
+        evaluation = scores + critiques
+        question = prompt["instruction"]
+        answer = prompt["answer"]
+        return (
+            self.user_prompt.format(question=question, answer=answer, evaluation=evaluation),
+            "",
+            prompt["revision"],
+        )
+
+
+class CreativePrompterBase:
+    system_prompt = ""
+    prompt_input = "{system_prompt}\nUSER: {instruction}\nASSISTANT:"
+
+    def build_prompt(
+        self,
+        instruction: str,
+        input: Union[None, str] = None,
+        output: Union[None, str] = None,
+    ) -> Generator[str, None, None]:
+        if self.system_prompt:
+            res = f"{self.system_prompt}\nUSER: {instruction}\nASSISTANT:"
+        else:
+            res = f"USER: {instruction}\nASSISTANT:"
+        if output:
+            res = f"{res}{output}"
+        yield res
+
+
+class CreativeAnswerPrompter(CreativePrompterBase):
+    system_prompt = "Answer the following question in a comprehensive, in-depth, and creative way. Additionally, your response should be relevant, accurate, and free of any ambiguity."
+
+
+class CreativeCritiquePrompter(CreativePrompterBase):
+    system_prompt = ""
+
+
+class CreativeRevisePrompter(CreativePrompterBase):
+    system_prompt = ""
+
+
+def load_answer(tokenizer, cfg):
+    return CreativeAnsweringPromptTokenizingStrategy(
+        CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
+
+
+def load_critique(tokenizer, cfg):
+    return CreativeCritiquePromptTokenizingStrategy(
+        CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
+
+
+def load_revise(tokenizer, cfg):
+    return CreativeRevisePromptTokenizingStrategy(
+        CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
diff --git a/src/axolotl/prompt_strategies/pygmalion.py b/src/axolotl/prompt_strategies/pygmalion.py
index 3b6cbf0e3..bd70c73d5 100644
--- a/src/axolotl/prompt_strategies/pygmalion.py
+++ b/src/axolotl/prompt_strategies/pygmalion.py
@@ -41,9 +41,9 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
             elif role == "bot":
                 prefix = "<|model|>"
                 res = self._tokenize(prefix + " " + message.strip(), add_eos_token=True, strip_bos_token=True)
-                res["input_ids"] = [*self.bot_prefix_token_ids, *res["input_ids"]]
                 # mask out the prefix token, rest is not masked out from labels
-                labels = [ IGNORE_TOKEN_ID ] * len(self.bot_prefix_token_ids) + [*copy.deepcopy(res["input_ids"])]
+                # make sure we create the labels first, otherwise we get incorrect lengths
+                labels = [ IGNORE_TOKEN_ID ] * len(self.bot_prefix_token_ids) + [*copy.deepcopy(res["input_ids"])][len(self.bot_prefix_token_ids):]
             else:
                 logging.warning(f"unknown role in conversation: {role}")
                 res = defaultdict(lambda: [])
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index a6e886138..f095cc9ab 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -75,7 +75,7 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
         ds = None
         ds_from_hub = False
         try:
-            load_dataset(d.path, streaming=True)
+            load_dataset(d.path, streaming=True, use_auth_token=True)
             ds_from_hub = True
         except FileNotFoundError:
             pass
@@ -83,18 +83,18 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
         # prefer local dataset, even if hub exists
         if Path(d.path).exists():
             ds: IterableDataset = load_dataset(
-                "json", data_files=d.path, streaming=True, split=None
+                "json", data_files=d.path, streaming=False, split=None
             )
         elif ds_from_hub:
             if d.data_files:
-                ds = load_dataset(d.path, streaming=True, data_files=d.data_files)
+                ds = load_dataset(d.path, streaming=False, data_files=d.data_files, use_auth_token=True)
             else:
-                ds = load_dataset(d.path, streaming=True)
+                ds = load_dataset(d.path, streaming=False, use_auth_token=True)
         else:
             fp = hf_hub_download(
                 repo_id=d.path, repo_type="dataset", filename=d.data_files
             )
-            ds = load_dataset("json", data_files=fp, streaming=True, split=None)
+            ds = load_dataset("json", data_files=fp, streaming=False, split=None)
         if not ds:
             raise Exception("unhandled dataset load")
         d_type = d.type
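
A note on the `load()` change in `prompt_strategies/__init__.py`: the strategy string may now carry the loader function name as its final dotted segment. The sketch below is not part of the diff; `resolve_strategy` is a hypothetical helper that isolates the resolution logic so it can be tested in isolation:

```python
# Hypothetical helper mirroring the resolution logic added to load():
# a trailing ".load_*" segment selects the loader function, and everything
# before it names the module under axolotl.prompt_strategies.
def resolve_strategy(strategy: str) -> tuple:
    load_fn = "load"
    if strategy.split(".")[-1].startswith("load_"):
        load_fn = strategy.split(".")[-1]
        strategy = ".".join(strategy.split(".")[:-1])
    return strategy, load_fn


assert resolve_strategy("pygmalion") == ("pygmalion", "load")
assert resolve_strategy("creative_acr.load_critique") == ("creative_acr", "load_critique")
```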
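
Since `load_tokenized_prepared_datasets` passes `d.type` through as the strategy string, a dataset config could plausibly select one of the new `creative_acr` loaders via its `type` field. A hedged example (the path and filename are placeholders, not from this diff):

```yaml
datasets:
  # placeholder path; "creative_acr.load_critique" resolves to the
  # load_critique() function in prompt_strategies/creative_acr.py
  - path: ./data/creative_acr_critiques.jsonl
    type: creative_acr.load_critique
```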
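
The `pygmalion.py` fix can be sanity-checked with made-up token ids: the `<|model|>` prefix is already tokenized into `res["input_ids"]`, so the labels must overwrite its positions with `IGNORE_TOKEN_ID` rather than be prepended, which previously left `labels` longer than `input_ids`. A minimal illustration, with all ids hypothetical:

```python
IGNORE_TOKEN_ID = -100
bot_prefix_token_ids = [50277, 209]       # hypothetical ids for "<|model|>"
input_ids = [50277, 209, 3923, 310, 247]  # prefix tokens followed by the message

# mask the prefix positions, keep the rest; lengths now stay equal
labels = [IGNORE_TOKEN_ID] * len(bot_prefix_token_ids) + input_ids[len(bot_prefix_token_ids):]
assert len(labels) == len(input_ids)
assert labels == [-100, -100, 3923, 310, 247]
```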
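
The `utils/data.py` hunks make two orthogonal changes: hub datasets are now loaded non-streaming (yielding map-style `Dataset` objects rather than `IterableDataset`s), and hub calls pass `use_auth_token=True` so private datasets resolve with the locally cached token. A minimal sketch of the new call pattern, with a placeholder dataset name:

```python
from datasets import load_dataset

# non-streaming load; use_auth_token=True picks up the token stored by
# `huggingface-cli login`, which private repos require
# ("org/some-dataset" is a placeholder, not a real dataset)
ds = load_dataset("org/some-dataset", streaming=False, use_auth_token=True)
```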