update prompts for open orca to match the paper (#317)
fix the test for the updated system tokenizer
This commit is contained in:
@@ -66,15 +66,34 @@ class SystemDataPrompter(AlpacaPrompter):
|
|||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
# returns the full prompt from instruction and optional input
|
# returns the full prompt from instruction and optional input
|
||||||
# if a label (=response, =output) is provided, it's also appended.
|
# if a label (=response, =output) is provided, it's also appended.
|
||||||
|
formatted_sys_prompt = f"### System:\n{system}\n\n" if system else ""
|
||||||
if input:
|
if input:
|
||||||
res = system + self.turn_format.format(instruction=instruction, input=input)
|
res = formatted_sys_prompt + self.turn_format.format(
|
||||||
|
instruction=instruction, input=input
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
res = system + self.turn_no_input_format.format(instruction=instruction)
|
res = formatted_sys_prompt + self.turn_no_input_format.format(
|
||||||
|
instruction=instruction
|
||||||
|
)
|
||||||
if output:
|
if output:
|
||||||
res = f"{res}{output}"
|
res = f"{res}{output}"
|
||||||
yield res
|
yield res
|
||||||
|
|
||||||
|
|
||||||
|
class OpenOrcaSystemDataPrompter(SystemDataPrompter):
|
||||||
|
"""
|
||||||
|
Alpaca Style Prompter that uses system prompts from the dataset, with OpenOrca prompts
|
||||||
|
"""
|
||||||
|
|
||||||
|
def match_prompt_style(self):
|
||||||
|
if self.prompt_style == PromptStyle.INSTRUCT.value:
|
||||||
|
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
|
||||||
|
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
|
||||||
|
if self.prompt_style == PromptStyle.CHAT.value:
|
||||||
|
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
||||||
|
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
||||||
|
|
||||||
|
|
||||||
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
|
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
|
||||||
"""
|
"""
|
||||||
Tokenizing strategy for OpenOrca datasets
|
Tokenizing strategy for OpenOrca datasets
|
||||||
@@ -113,7 +132,7 @@ def load_chat(tokenizer, cfg):
|
|||||||
|
|
||||||
def load_open_orca(tokenizer, cfg):
|
def load_open_orca(tokenizer, cfg):
|
||||||
return OpenOrcaPromptTokenizingStrategy(
|
return OpenOrcaPromptTokenizingStrategy(
|
||||||
SystemDataPrompter(PromptStyle.INSTRUCT.value),
|
OpenOrcaSystemDataPrompter(PromptStyle.INSTRUCT.value),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
|
|||||||
@@ -130,8 +130,9 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
|
|||||||
"output": "Hi! How can I help?",
|
"output": "Hi! How can I help?",
|
||||||
}
|
}
|
||||||
example = strat.tokenize_prompt(sample)
|
example = strat.tokenize_prompt(sample)
|
||||||
assert example["input_ids"][0:3] == [1, 671, 20118] # <s>use cot
|
assert example["input_ids"][0:4] == [1, 835, 2184, 29901] # "<s>### System:"
|
||||||
assert example["input_ids"][3] == 11889 # USER
|
assert example["input_ids"][5:7] == [1509, 20118] # "use cot"
|
||||||
|
assert example["input_ids"][9] == 11889 # USER
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class AlpacaPrompterTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
assert "use cot" in res
|
assert "use cot" in res
|
||||||
assert res.startswith("use cot")
|
assert res.startswith("### System:")
|
||||||
assert "### Instruction:" not in res
|
assert "### Instruction:" not in res
|
||||||
assert "### Input:" not in res
|
assert "### Input:" not in res
|
||||||
assert "alpacas" in res
|
assert "alpacas" in res
|
||||||
|
|||||||
Reference in New Issue
Block a user