bugfix for potential off by one
This commit is contained in:
@@ -40,6 +40,18 @@ class AlpacaChatPrompter(AlpacaPrompter):
|
||||
self.match_prompt_style()
|
||||
|
||||
|
||||
class NoSystemPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Null Prompter with no system prompts
|
||||
"""
|
||||
|
||||
prompt_input = "{instruction} {input} "
|
||||
prompt_no_input = "{instruction} "
|
||||
|
||||
def __init__(self): # pylint: disable=super-init-not-called
|
||||
pass
|
||||
|
||||
|
||||
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for AlpacaQA
|
||||
|
||||
@@ -96,25 +96,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
input, # pylint: disable=redefined-builtin
|
||||
response,
|
||||
) = self.parse_instruction_fields(prompt)
|
||||
full_prompt = self._build_full_prompt(instruction, input, response)
|
||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||
if not self.train_on_inputs:
|
||||
user_prompt = next(
|
||||
iter(
|
||||
self.prompter.build_prompt(
|
||||
instruction,
|
||||
input,
|
||||
)
|
||||
user_prompt = next(
|
||||
iter(
|
||||
self.prompter.build_prompt(
|
||||
instruction,
|
||||
input,
|
||||
)
|
||||
)
|
||||
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||
)
|
||||
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||
if not self.train_on_inputs:
|
||||
user_prompt_len = len(tokenized_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_full_prompt["labels"] = [
|
||||
-100
|
||||
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
|
||||
tokenized_prompt["labels"] = [-100] * user_prompt_len
|
||||
tokenized_res_prompt = self._tokenize(
|
||||
response, strip_bos_token=True, add_eos_token=True
|
||||
)
|
||||
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
|
||||
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
|
||||
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
|
||||
|
||||
return tokenized_full_prompt
|
||||
return tokenized_prompt
|
||||
|
||||
def _build_full_prompt(
|
||||
self, instruction, input, response # pylint: disable=redefined-builtin
|
||||
|
||||
Reference in New Issue
Block a user