feat: add Metharme prompt strategy (#446)
* Add Metharme tokenizing strategy This strategy accounts for how the Metharme JSONLs are formatted as well as adds duplicated EOS tokens which can help trim model output length. I haven't gotten the chance to test this yet, and probably won't have the chance for quite a bit, so I'm committing this now. * Redo Metharme tokenizing strategy lol * fix: oops * Rearrange a conditional * chore: reformat code in accordance with linter * chore: Make lint not freak out * chore: fix lint --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
This commit is contained in:
@@ -257,6 +257,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"conversations": [{"role": "...", "value": "..."}]}
|
{"conversations": [{"role": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
|
- `metharme`: instruction, adds additional eos tokens
|
||||||
|
```json
|
||||||
|
{"prompt": "...", "generation": "..."}
|
||||||
|
```
|
||||||
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
|
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
|
||||||
```json
|
```json
|
||||||
{"conversations": [{"role": "...", "value": "..."}]}
|
{"conversations": [{"role": "...", "value": "..."}]}
|
||||||
|
|||||||
76
src/axolotl/prompt_strategies/metharme.py
Normal file
76
src/axolotl/prompt_strategies/metharme.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
"""Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
||||||
|
from axolotl.prompters import AlpacaPrompter
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
IGNORE_TOKEN_ID = -100
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
|
||||||
|
class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||||
|
"""
|
||||||
|
Tokenizing strategy for the Metharme models
|
||||||
|
"""
|
||||||
|
|
||||||
|
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||||
|
return (prompt["prompt"], "", prompt["generation"])
|
||||||
|
|
||||||
|
def _tokenize(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
add_eos_token: bool = True,
|
||||||
|
strip_bos_token: bool = False,
|
||||||
|
num_eos_tokens: int = 3,
|
||||||
|
):
|
||||||
|
result = self.tokenizer(
|
||||||
|
prompt,
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.sequence_len,
|
||||||
|
padding=False,
|
||||||
|
return_tensors=None,
|
||||||
|
)
|
||||||
|
if len(result["input_ids"]) == 0:
|
||||||
|
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
|
||||||
|
# If there's already an EOS token there, subtract from the number added
|
||||||
|
if result["input_ids"][-1] == self.tokenizer.eos_token_id:
|
||||||
|
num_eos_tokens -= 1
|
||||||
|
|
||||||
|
if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0:
|
||||||
|
for _ in range(num_eos_tokens):
|
||||||
|
if len(result["input_ids"]) < self.sequence_len:
|
||||||
|
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||||
|
result["attention_mask"].append(1)
|
||||||
|
|
||||||
|
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||||
|
result["input_ids"] = result["input_ids"][1:]
|
||||||
|
result["attention_mask"] = result["attention_mask"][1:]
|
||||||
|
|
||||||
|
result["labels"] = result["input_ids"].copy()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class MetharmePrompter(AlpacaPrompter):
|
||||||
|
"""
|
||||||
|
Prompter for the Metharme models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
system_prompt = ""
|
||||||
|
system_no_input_prompt = ""
|
||||||
|
system_format = ""
|
||||||
|
turn_format = "{instruction}"
|
||||||
|
turn_no_input_format = "{instruction}"
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def load(tokenizer, cfg):
|
||||||
|
return MetharmePromptTokenizingStrategy(
|
||||||
|
MetharmePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user