diff --git a/README.md b/README.md
index 7229d0a70..20bde1786 100644
--- a/README.md
+++ b/README.md
@@ -257,6 +257,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
   ```json
   {"conversations": [{"role": "...", "value": "..."}]}
   ```
+- `metharme`: instruction, adds additional eos tokens
+  ```json
+  {"prompt": "...", "generation": "..."}
+  ```
 - `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
   ```json
   {"conversations": [{"role": "...", "value": "..."}]}
diff --git a/src/axolotl/prompt_strategies/metharme.py b/src/axolotl/prompt_strategies/metharme.py
new file mode 100644
index 000000000..52d77c00c
--- /dev/null
+++ b/src/axolotl/prompt_strategies/metharme.py
@@ -0,0 +1,92 @@
+"""Module containing the MetharmePromptTokenizingStrategy and MetharmePrompter classes"""
+
+import logging
+from typing import Tuple
+
+from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter
+
+LOG = logging.getLogger("axolotl")
+
+IGNORE_TOKEN_ID = -100
+
+# pylint: disable=duplicate-code
+
+
+class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    """
+    Tokenizing strategy for the Metharme models
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
+        """Return (prompt, "", generation) taken from the dataset row."""
+        return (prompt["prompt"], "", prompt["generation"])
+
+    def _tokenize(
+        self,
+        prompt: str,
+        add_eos_token: bool = True,
+        strip_bos_token: bool = False,
+        num_eos_tokens: int = 3,
+    ):
+        """
+        Tokenize the prompt and end it with up to num_eos_tokens EOS tokens.
+
+        An EOS token already at the end counts toward num_eos_tokens, and EOS
+        tokens are only appended while the result is shorter than sequence_len.
+        An empty tokenizer result is logged and returned with empty labels.
+        """
+        result = self.tokenizer(
+            prompt,
+            truncation=True,
+            max_length=self.sequence_len,
+            padding=False,
+            return_tensors=None,
+        )
+        input_ids = result["input_ids"]
+        if not input_ids:
+            LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
+            # Nothing to terminate or strip, and indexing below would raise IndexError
+            result["labels"] = []
+            return result
+
+        # If there's already an EOS token there, subtract from the number added
+        if input_ids[-1] == self.tokenizer.eos_token_id:
+            num_eos_tokens -= 1
+
+        if add_eos_token:
+            # Never grow the sequence past sequence_len
+            num_to_add = min(num_eos_tokens, self.sequence_len - len(input_ids))
+            if num_to_add > 0:
+                input_ids.extend([self.tokenizer.eos_token_id] * num_to_add)
+                result["attention_mask"].extend([1] * num_to_add)
+
+        if strip_bos_token and input_ids[0] == self.tokenizer.bos_token_id:
+            result["input_ids"] = input_ids[1:]
+            result["attention_mask"] = result["attention_mask"][1:]
+
+        result["labels"] = result["input_ids"].copy()
+        return result
+
+
+class MetharmePrompter(AlpacaPrompter):
+    """
+    Prompter for the Metharme models.
+    """
+
+    system_prompt = ""
+    system_no_input_prompt = ""
+    system_format = ""
+    turn_format = "{instruction}"
+    turn_no_input_format = "{instruction}"
+
+    def __init__(self, *args, **kwargs):  # pylint: disable=super-init-not-called
+        # Arguments are ignored; the class attributes above define the formats
+        pass
+
+
+def load(tokenizer, cfg):
+    """Build the Metharme strategy from cfg.train_on_inputs and cfg.sequence_len."""
+    return MetharmePromptTokenizingStrategy(
+        MetharmePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )