Refactor duplicate code between Prompter and Pygmalion
This commit is contained in:
@@ -5,7 +5,11 @@ import logging
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import (
|
||||||
|
PromptTokenizingStrategy,
|
||||||
|
parse_tokenized_to_result,
|
||||||
|
tokenize_prompt_default,
|
||||||
|
)
|
||||||
|
|
||||||
IGNORE_TOKEN_ID = -100
|
IGNORE_TOKEN_ID = -100
|
||||||
|
|
||||||
@@ -23,12 +27,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
self.bot_prefix_token_ids = res["input_ids"]
|
self.bot_prefix_token_ids = res["input_ids"]
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
result = {
|
result, current_len = tokenize_prompt_default()
|
||||||
"input_ids": [],
|
|
||||||
"attention_mask": [],
|
|
||||||
"labels": [],
|
|
||||||
}
|
|
||||||
current_len = 0
|
|
||||||
for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
|
for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
|
||||||
role, message = part
|
role, message = part
|
||||||
if role == "system":
|
if role == "system":
|
||||||
@@ -67,37 +66,15 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
else:
|
else:
|
||||||
logging.warning(f"unknown role in conversation: {role}")
|
logging.warning(f"unknown role in conversation: {role}")
|
||||||
res = defaultdict(lambda: [])
|
res = defaultdict(lambda: [])
|
||||||
input_ids = res["input_ids"]
|
|
||||||
input_len = len(input_ids)
|
|
||||||
result["input_ids"][current_len : current_len + input_len] = input_ids
|
|
||||||
result["attention_mask"][current_len : current_len + input_len] = [
|
|
||||||
1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
|
|
||||||
]
|
|
||||||
result["labels"][current_len : current_len + input_len] = labels
|
|
||||||
current_len += input_len
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
# pylint: disable=duplicate-code
|
||||||
result = self.tokenizer(
|
result, current_len = parse_tokenized_to_result(
|
||||||
prompt,
|
result,
|
||||||
truncation=True,
|
current_len,
|
||||||
max_length=self.sequence_len,
|
res,
|
||||||
padding=False,
|
labels,
|
||||||
return_tensors=None,
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
)
|
)
|
||||||
if (
|
|
||||||
result["input_ids"][-1] != self.tokenizer.eos_token_id
|
|
||||||
and len(result["input_ids"]) < self.sequence_len
|
|
||||||
and add_eos_token
|
|
||||||
):
|
|
||||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
|
||||||
result["attention_mask"].append(1)
|
|
||||||
|
|
||||||
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
|
||||||
result["input_ids"] = result["input_ids"][1:]
|
|
||||||
result["attention_mask"] = result["attention_mask"][1:]
|
|
||||||
|
|
||||||
result["labels"] = result["input_ids"].copy()
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import abc
|
|||||||
import copy
|
import copy
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
from typing import Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
@@ -58,6 +58,29 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
return id_or_ids
|
return id_or_ids
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
|
||||||
|
result = self.tokenizer(
|
||||||
|
prompt,
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.sequence_len,
|
||||||
|
padding=False,
|
||||||
|
return_tensors=None,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||||
|
and len(result["input_ids"]) < self.sequence_len
|
||||||
|
and add_eos_token
|
||||||
|
):
|
||||||
|
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||||
|
result["attention_mask"].append(1)
|
||||||
|
|
||||||
|
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||||
|
result["input_ids"] = result["input_ids"][1:]
|
||||||
|
result["attention_mask"] = result["attention_mask"][1:]
|
||||||
|
|
||||||
|
result["labels"] = result["input_ids"].copy()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||||
"""
|
"""
|
||||||
@@ -106,29 +129,6 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
|
||||||
result = self.tokenizer(
|
|
||||||
prompt,
|
|
||||||
truncation=True,
|
|
||||||
max_length=self.sequence_len,
|
|
||||||
padding=False,
|
|
||||||
return_tensors=None,
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
result["input_ids"][-1] != self.tokenizer.eos_token_id
|
|
||||||
and len(result["input_ids"]) < self.sequence_len
|
|
||||||
and add_eos_token
|
|
||||||
):
|
|
||||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
|
||||||
result["attention_mask"].append(1)
|
|
||||||
|
|
||||||
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
|
||||||
result["input_ids"] = result["input_ids"][1:]
|
|
||||||
result["attention_mask"] = result["attention_mask"][1:]
|
|
||||||
|
|
||||||
result["labels"] = result["input_ids"].copy()
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||||
"""
|
"""
|
||||||
@@ -295,7 +295,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _tokenize(self, prompt, add_eos_token=True):
|
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||||
result = self.tokenizer(
|
result = self.tokenizer(
|
||||||
prompt,
|
prompt,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
@@ -339,12 +339,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
return prompt["conversations"]
|
return prompt["conversations"]
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
result = {
|
result, current_len = tokenize_prompt_default()
|
||||||
"input_ids": [],
|
|
||||||
"attention_mask": [],
|
|
||||||
"labels": [],
|
|
||||||
}
|
|
||||||
current_len = 0
|
|
||||||
user_token = self._get_user_token()
|
user_token = self._get_user_token()
|
||||||
assistant_token = self._get_assistant_token()
|
assistant_token = self._get_assistant_token()
|
||||||
try:
|
try:
|
||||||
@@ -382,14 +377,15 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
)
|
)
|
||||||
# everything from this is masked out from the labels
|
# everything from this is masked out from the labels
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
input_ids = res["input_ids"]
|
|
||||||
input_len = len(input_ids)
|
# pylint: disable=duplicate-code
|
||||||
result["input_ids"][current_len : current_len + input_len] = input_ids
|
result, current_len = parse_tokenized_to_result(
|
||||||
result["attention_mask"][current_len : current_len + input_len] = [
|
result,
|
||||||
1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
|
current_len,
|
||||||
]
|
res,
|
||||||
result["labels"][current_len : current_len + input_len] = labels
|
labels,
|
||||||
current_len += input_len
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
except (KeyError, AssertionError, IndexError) as err:
|
except (KeyError, AssertionError, IndexError) as err:
|
||||||
raise InvalidDataException(str(err)) from err
|
raise InvalidDataException(str(err)) from err
|
||||||
@@ -416,3 +412,40 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
result["labels"] = result["input_ids"].copy()
|
result["labels"] = result["input_ids"].copy()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
|
||||||
|
"""
|
||||||
|
Returns the default values for the tokenize prompt function
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"input_ids": [],
|
||||||
|
"attention_mask": [],
|
||||||
|
"labels": [],
|
||||||
|
}
|
||||||
|
current_len = 0
|
||||||
|
return result, current_len
|
||||||
|
|
||||||
|
|
||||||
|
def parse_tokenized_to_result(
|
||||||
|
result: Dict[str, List[int]],
|
||||||
|
current_len: int,
|
||||||
|
res: Dict[str, List[int]],
|
||||||
|
labels: list[int],
|
||||||
|
pad_token_id: int | None = None,
|
||||||
|
) -> Tuple[Dict[str, List[int]], int]:
|
||||||
|
"""
|
||||||
|
Parses the tokenized prompt and append the tokenized input_ids, attention_mask and labels to the result
|
||||||
|
"""
|
||||||
|
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
input_len = len(input_ids)
|
||||||
|
result["input_ids"][current_len : current_len + input_len] = input_ids
|
||||||
|
result["attention_mask"][current_len : current_len + input_len] = [
|
||||||
|
1 if x != pad_token_id else 0 for x in input_ids
|
||||||
|
]
|
||||||
|
result["labels"][current_len : current_len + input_len] = labels
|
||||||
|
current_len += input_len
|
||||||
|
|
||||||
|
return result, current_len
|
||||||
|
|||||||
Reference in New Issue
Block a user