add tests and supoort for loader for sys prompt data
This commit is contained in:
83
src/axolotl/prompt_strategies/alpaca_w_system.py
Normal file
83
src/axolotl/prompt_strategies/alpaca_w_system.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
Prompt strategies loader for alpaca instruction datasets with system prompts
|
||||
"""
|
||||
from typing import Generator, Tuple, Union
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
|
||||
|
||||
class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for instruction-based prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
|
||||
return (
|
||||
prompt["instruction"],
|
||||
prompt["input"] if "input" in prompt else "",
|
||||
prompt["output"],
|
||||
prompt["system"],
|
||||
)
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
(
|
||||
instruction,
|
||||
input, # pylint: disable=redefined-builtin
|
||||
response,
|
||||
system,
|
||||
) = self.parse_instruction_fields(prompt)
|
||||
user_prompt = next(
|
||||
iter(
|
||||
self.prompter.build_prompt_w_system(
|
||||
system,
|
||||
instruction,
|
||||
input,
|
||||
)
|
||||
)
|
||||
)
|
||||
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||
if not self.train_on_inputs:
|
||||
user_prompt_len = len(tokenized_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_prompt["labels"] = [-100] * user_prompt_len
|
||||
tokenized_res_prompt = self._tokenize(
|
||||
response, strip_bos_token=True, add_eos_token=True
|
||||
)
|
||||
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
|
||||
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
|
||||
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
class SystemDataPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Alpaca Style Prompter that uses system prompts from the dataset
|
||||
"""
|
||||
|
||||
def build_prompt_w_system(
|
||||
self,
|
||||
system: str,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
res = system + self.turn_format.format(instruction=instruction, input=input)
|
||||
else:
|
||||
res = system + self.turn_no_input_format.format(instruction=instruction)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
yield res
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return InstructionWSystemPromptTokenizingStrategy(
|
||||
SystemDataPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
@@ -63,29 +63,6 @@ class AlpacaPrompter:
|
||||
yield res
|
||||
|
||||
|
||||
class SystemDataPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Alpaca Style Prompter that uses system prompts from the dataset
|
||||
"""
|
||||
|
||||
def build_prompt_w_system(
|
||||
self,
|
||||
system: str,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
res = system + self.turn_format.format(instruction=instruction, input=input)
|
||||
else:
|
||||
res = system + self.turn_no_input_format.format(instruction=instruction)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
yield res
|
||||
|
||||
|
||||
class UnpromptedPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Prompter for alpaca no system prompt
|
||||
|
||||
@@ -34,3 +34,5 @@ def check_example_labels(example, tokenizer):
|
||||
|
||||
logging.info(" ".join(colored_tokens))
|
||||
logging.info("\n\n\n")
|
||||
|
||||
return " ".join(colored_tokens)
|
||||
|
||||
Reference in New Issue
Block a user