From f4746507f62b6592a66683087b60f498add3111a Mon Sep 17 00:00:00 2001 From: TearGosling <119216654+TearGosling@users.noreply.github.com> Date: Mon, 21 Aug 2023 21:21:45 -0500 Subject: [PATCH] feat: add Metharme prompt strategy (#446) * Add Metharme tokenizing strategy This strategy accounts for how the Metharme JSONLs are formatted as well as adds duplicated EOS tokens which can help trim model output length. I haven't gotten the chance to test this yet, and probably won't have the chance for quite a bit, so I'm committing this now. * Redo Metharme tokenizing strategy lol * fix: oops * Rearrange a conditional * chore: reformat code in accordance with linter * chore: Make lint not freak out * chore: fix lint --------- Co-authored-by: NanoCode012 --- README.md | 4 ++ src/axolotl/prompt_strategies/metharme.py | 76 +++++++++++++++++++++++ 2 files changed, 80 insertions(+) create mode 100644 src/axolotl/prompt_strategies/metharme.py diff --git a/README.md b/README.md index 7229d0a70..20bde1786 100644 --- a/README.md +++ b/README.md @@ -257,6 +257,10 @@ Have dataset(s) in one of the following format (JSONL recommended): ```json {"conversations": [{"role": "...", "value": "..."}]} ``` +- `metharme`: instruction, adds additional eos tokens + ```json + {"prompt": "...", "generation": "..."} + ``` - `sharegpt_simple.load_role`: conversations where `role` is used instead of `from` ```json {"conversations": [{"role": "...", "value": "..."}]} diff --git a/src/axolotl/prompt_strategies/metharme.py b/src/axolotl/prompt_strategies/metharme.py new file mode 100644 index 000000000..52d77c00c --- /dev/null +++ b/src/axolotl/prompt_strategies/metharme.py @@ -0,0 +1,76 @@ +"""Module containing the MetharmePromptTokenizingStrategy and MetharmePrompter classes""" + +import logging +from typing import Tuple + +from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy +from axolotl.prompters import AlpacaPrompter + +LOG = logging.getLogger("axolotl") + +IGNORE_TOKEN_ID = 
-100

# pylint: disable=duplicate-code


class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for the Metharme models
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        # Metharme rows only carry "prompt" and "generation"; there is no
        # separate input field, so the middle element is always empty.
        return (prompt["prompt"], "", prompt["generation"])

    def _tokenize(
        self,
        prompt: str,
        add_eos_token: bool = True,
        strip_bos_token: bool = False,
        num_eos_tokens: int = 3,
    ):
        """Tokenize ``prompt`` and append up to ``num_eos_tokens`` EOS tokens.

        Duplicating the EOS token at the end of each example helps trim
        model output length at inference time. The sequence is never grown
        past ``self.sequence_len``, and ``labels`` are a plain copy of
        ``input_ids``.
        """
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        if len(result["input_ids"]) == 0:
            LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
        # If there's already an EOS token there, subtract from the number added.
        # BUG FIX: guard the [-1] access — the original raised IndexError when
        # the tokenizer produced no tokens at all (the case warned about above).
        if (
            len(result["input_ids"]) > 0
            and result["input_ids"][-1] == self.tokenizer.eos_token_id
        ):
            num_eos_tokens -= 1

        if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0:
            for _ in range(num_eos_tokens):
                # Never grow the sequence past the configured sequence length.
                if len(result["input_ids"]) < self.sequence_len:
                    result["input_ids"].append(self.tokenizer.eos_token_id)
                    result["attention_mask"].append(1)

        # BUG FIX: same empty-result guard for the [0] (BOS) access.
        if (
            strip_bos_token
            and len(result["input_ids"]) > 0
            and result["input_ids"][0] == self.tokenizer.bos_token_id
        ):
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
        return result


class MetharmePrompter(AlpacaPrompter):
    """
    Prompter for the Metharme models.

    All templates are pass-through: the raw "prompt" text already contains the
    fully formatted turn, so no system prompt or wrapping is applied.
    """

    system_prompt = ""
    system_no_input_prompt = ""
    system_format = ""
    turn_format = "{instruction}"
    turn_no_input_format = "{instruction}"

    def __init__(self, *args, **kwargs):  # pylint: disable=super-init-not-called
        # Deliberately skip AlpacaPrompter.__init__: it compiles prompt
        # templates that this pass-through prompter does not use.
        pass


def load(tokenizer, cfg):
    """Entry point used by axolotl to build the Metharme tokenizing strategy."""
    return MetharmePromptTokenizingStrategy(
        MetharmePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )