split completion text to sequence_len (#616)
This commit is contained in:
@@ -38,10 +38,15 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
def process(self, dataset):
|
def process(self, dataset):
|
||||||
features = dataset.features.keys()
|
features = dataset.features.keys()
|
||||||
num_proc = min(64, os.cpu_count())
|
num_proc = min(64, os.cpu_count())
|
||||||
|
map_kwargs = {}
|
||||||
|
if self.prompt_tokenizer.supports_batched:
|
||||||
|
map_kwargs["batched"] = True
|
||||||
|
map_kwargs["batch_size"] = 100
|
||||||
return dataset.map(
|
return dataset.map(
|
||||||
self.prompt_tokenizer.tokenize_prompt,
|
self.prompt_tokenizer.tokenize_prompt,
|
||||||
num_proc=num_proc,
|
num_proc=num_proc,
|
||||||
remove_columns=features,
|
remove_columns=features,
|
||||||
|
**map_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,81 @@
|
|||||||
"""
|
"""
|
||||||
Basic completion text
|
Basic completion text
|
||||||
"""
|
"""
|
||||||
from typing import Any, Dict, Optional
|
from collections import defaultdict
|
||||||
|
from typing import Any, Dict, Generator, Optional, Tuple
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import CompletionPromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
||||||
from axolotl.prompters import CompletionPrompter
|
|
||||||
|
|
||||||
|
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||||
|
"""
|
||||||
|
Tokenizing strategy for Completion prompts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_field: str = "text"
|
||||||
|
|
||||||
|
def __init__(self, *args, max_length=None, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
if max_length is not None:
|
||||||
|
self.max_length = max_length
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_batched(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def field(self) -> str:
|
||||||
|
return self._field
|
||||||
|
|
||||||
|
@field.setter
|
||||||
|
def field(self, new_field: str):
|
||||||
|
self._field = new_field
|
||||||
|
|
||||||
|
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||||
|
return (
|
||||||
|
prompt[self.field],
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
|
||||||
|
def tokenize_prompt(self, prompt):
|
||||||
|
res = defaultdict(lambda: [])
|
||||||
|
feature_names = list(prompt.keys())
|
||||||
|
for row in zip(*prompt.values()):
|
||||||
|
prompt_row = dict(zip(feature_names, row))
|
||||||
|
(
|
||||||
|
instruction,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
) = self.parse_instruction_fields(prompt_row)
|
||||||
|
|
||||||
|
full_prompt = self._build_full_prompt(instruction, None, None)
|
||||||
|
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||||
|
|
||||||
|
for key, val in tokenized_full_prompt.items():
|
||||||
|
for i in range(0, len(val), self.sequence_len):
|
||||||
|
res[key].append(val[i : i + self.sequence_len])
|
||||||
|
|
||||||
|
return dict(res)
|
||||||
|
|
||||||
|
def _build_full_prompt(
|
||||||
|
self, instruction, input, response
|
||||||
|
): # pylint: disable=redefined-builtin
|
||||||
|
return next(iter(self.prompter.build_prompt(instruction, input, response)))
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionPrompter:
|
||||||
|
"""
|
||||||
|
Prompter for completion
|
||||||
|
"""
|
||||||
|
|
||||||
|
def build_prompt(
|
||||||
|
self,
|
||||||
|
instruction: str,
|
||||||
|
input=None, # pylint: disable=redefined-builtin, unused-argument
|
||||||
|
output=None, # pylint: disable=unused-argument
|
||||||
|
) -> Generator[str, None, None]:
|
||||||
|
yield instruction
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
@@ -13,6 +84,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
|
max_length=cfg.sequence_len * 64,
|
||||||
)
|
)
|
||||||
if ds_cfg and "field" in ds_cfg:
|
if ds_cfg and "field" in ds_cfg:
|
||||||
strat.field = ds_cfg["field"]
|
strat.field = ds_cfg["field"]
|
||||||
|
|||||||
@@ -41,11 +41,16 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
self.tokenizer: PreTrainedTokenizer = tokenizer
|
self.tokenizer: PreTrainedTokenizer = tokenizer
|
||||||
self.train_on_inputs = train_on_inputs
|
self.train_on_inputs = train_on_inputs
|
||||||
self.sequence_len = sequence_len
|
self.sequence_len = sequence_len
|
||||||
|
self.max_length = sequence_len
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_batched(self):
|
||||||
|
return False
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=128)
|
@functools.lru_cache(maxsize=128)
|
||||||
def _get_user_token(self):
|
def _get_user_token(self):
|
||||||
try:
|
try:
|
||||||
@@ -77,7 +82,7 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
result = self.tokenizer(
|
result = self.tokenizer(
|
||||||
prompt,
|
prompt,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=self.sequence_len,
|
max_length=self.max_length,
|
||||||
padding=False,
|
padding=False,
|
||||||
return_tensors=None,
|
return_tensors=None,
|
||||||
)
|
)
|
||||||
@@ -86,7 +91,7 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
if (
|
if (
|
||||||
len(result["input_ids"]) > 0
|
len(result["input_ids"]) > 0
|
||||||
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||||
and len(result["input_ids"]) < self.sequence_len
|
and len(result["input_ids"]) < self.max_length
|
||||||
and add_eos_token
|
and add_eos_token
|
||||||
):
|
):
|
||||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||||
@@ -247,46 +252,6 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
|
||||||
"""
|
|
||||||
Tokenizing strategy for Completion prompts.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_field: str = "text"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def field(self) -> str:
|
|
||||||
return self._field
|
|
||||||
|
|
||||||
@field.setter
|
|
||||||
def field(self, new_field: str):
|
|
||||||
self._field = new_field
|
|
||||||
|
|
||||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
|
||||||
return (
|
|
||||||
prompt[self.field],
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
|
||||||
(
|
|
||||||
instruction,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
) = self.parse_instruction_fields(prompt)
|
|
||||||
|
|
||||||
full_prompt = self._build_full_prompt(instruction, None, None)
|
|
||||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
|
||||||
|
|
||||||
return tokenized_full_prompt
|
|
||||||
|
|
||||||
def _build_full_prompt(
|
|
||||||
self, instruction, input, response
|
|
||||||
): # pylint: disable=redefined-builtin
|
|
||||||
return next(iter(self.prompter.build_prompt(instruction, input, response)))
|
|
||||||
|
|
||||||
|
|
||||||
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||||
"""
|
"""
|
||||||
Tokenizing strategy for Reflection prompts.
|
Tokenizing strategy for Reflection prompts.
|
||||||
|
|||||||
@@ -135,20 +135,6 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
|
|||||||
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
|
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
|
||||||
|
|
||||||
|
|
||||||
class CompletionPrompter:
|
|
||||||
"""
|
|
||||||
Prompter for completion
|
|
||||||
"""
|
|
||||||
|
|
||||||
def build_prompt(
|
|
||||||
self,
|
|
||||||
instruction: str,
|
|
||||||
input=None, # pylint: disable=redefined-builtin, unused-argument
|
|
||||||
output=None, # pylint: disable=unused-argument
|
|
||||||
) -> Generator[str, None, None]:
|
|
||||||
yield instruction
|
|
||||||
|
|
||||||
|
|
||||||
class GPTeacherPrompter(AlpacaPrompter):
|
class GPTeacherPrompter(AlpacaPrompter):
|
||||||
"""
|
"""
|
||||||
Prompter for GPTeacher
|
Prompter for GPTeacher
|
||||||
|
|||||||
Reference in New Issue
Block a user