diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py
index 0f8c31d6a..6161d7e37 100644
--- a/src/axolotl/prompt_strategies/alpaca_chat.py
+++ b/src/axolotl/prompt_strategies/alpaca_chat.py
@@ -20,11 +20,36 @@ def load(tokenizer, cfg):
 
 class AlpacaConcisePrompter(AlpacaPrompter):
     """
-    Alpaca Prompter extending the system prompt to ask for concise answers
+    Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
     """
 
-    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that concisely and appropriately completes the request.\n\n"
-    system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately and concisely completes the request.\n\n"
+    system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
+    system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
+
+
+class AlpacaChatPrompter(AlpacaPrompter):
+    """
+    Alpaca Chat Prompter extending the system prompt to ask for chat-instruct answers
+    """
+
+    system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
+    system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
+
+    def __init__(self):  # pylint: disable=super-init-not-called
+        self.prompt_style = PromptStyle.CHAT.value
+        self.match_prompt_style()
+
+
+class NoSystemPrompter(AlpacaPrompter):
+    """
+    Null Prompter with no system prompts
+    """
+
+    prompt_input = "{instruction} {input} "
+    prompt_no_input = "{instruction} "
+
+    def __init__(self):  # pylint: disable=super-init-not-called
+        pass
 
 
 class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
@@ -64,7 +89,7 @@ def load_concise(tokenizer, cfg):
 
 def load_qa(tokenizer, cfg):
     return AlpacaQAPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.CHAT.value),
+        AlpacaChatPrompter(),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
@@ -73,7 +98,7 @@ def load_qa(tokenizer, cfg):
 
 def load_camel_ai(tokenizer, cfg):
     return CamelAIPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.CHAT.value),
+        AlpacaChatPrompter(),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 8b3c88fee..6408620d7 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -96,25 +96,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
             input,  # pylint: disable=redefined-builtin
             response,
         ) = self.parse_instruction_fields(prompt)
-        full_prompt = self._build_full_prompt(instruction, input, response)
-        tokenized_full_prompt = self._tokenize(full_prompt)
-        if not self.train_on_inputs:
-            user_prompt = next(
-                iter(
-                    self.prompter.build_prompt(
-                        instruction,
-                        input,
-                    )
+        user_prompt = next(
+            iter(
+                self.prompter.build_prompt(
+                    instruction,
+                    input,
                 )
             )
-            tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
-            user_prompt_len = len(tokenized_user_prompt["input_ids"])
+        )
+        tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
+        if not self.train_on_inputs:
+            user_prompt_len = len(tokenized_prompt["input_ids"])
             # TODO this could be sped up using numpy array slicing
-            tokenized_full_prompt["labels"] = [
-                -100
-            ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
+            tokenized_prompt["labels"] = [-100] * user_prompt_len
+        tokenized_res_prompt = self._tokenize(
+            response, strip_bos_token=True, add_eos_token=True
+        )
+        tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
+        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
+        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
 
-        return tokenized_full_prompt
+        return tokenized_prompt
 
     def _build_full_prompt(
         self, instruction, input, response  # pylint: disable=redefined-builtin
diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py
index 89209e84f..aba340eee 100644
--- a/tests/test_prompt_tokenizers.py
+++ b/tests/test_prompt_tokenizers.py
@@ -6,8 +6,12 @@ from pathlib import Path
 
 from transformers import AutoTokenizer
 
-from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
-from axolotl.prompters import ShareGPTPrompter
+from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
+from axolotl.prompt_tokenizers import (
+    AlpacaPromptTokenizingStrategy,
+    ShareGPTPromptTokenizingStrategy,
+)
+from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter
 
 logging.basicConfig(level="INFO")
 
@@ -29,7 +33,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
         )
 
     def test_sharegpt_integration(self):
-        print(Path(__file__).parent)
         with open(
             Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
         ) as fin:
@@ -53,6 +56,45 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
             self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
             self.assertEqual(example[fields], tokenized_conversation[fields])
 
+    def test_no_sys_prompt(self):
+        """
+        tests the interface between the user and assistant parts
+        """
+        prompter = NoSystemPrompter()
+        # pylint: disable=duplicate-code
+        strat = AlpacaPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        sample = {
+            "instruction": "hello cruel. lorem ipsum dolor sit amet.",
+            "output": "world!",
+        }
+        example = strat.tokenize_prompt(sample)
+        world_idx = example["input_ids"].index(3186)
+        assert example["labels"][world_idx] == 3186
+        assert example["labels"][world_idx - 1] == -100
+
+    def test_alpaca(self):
+        """
+        tests the interface between the user and assistant parts
+        """
+        # pylint: disable=duplicate-code
+        prompter = AlpacaPrompter()
+        strat = AlpacaPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
+        example = strat.tokenize_prompt(sample)
+        world_idx = example["input_ids"].index(6324)
+        assert example["labels"][world_idx] == 6324
+        assert example["labels"][world_idx - 1] == -100
+
 
 if __name__ == "__main__":
     unittest.main()
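Note on the prompt_tokenizers.py hunk: the old code tokenized the full prompt-plus-response string and then overwrote the first user_prompt_len labels with -100, which assumed the user prompt tokenizes to an exact token prefix of the concatenated text; SentencePiece-style tokenizers do not guarantee that, so the mask boundary could drift by a token. The new code tokenizes the user prompt (add_eos_token=False) and the response (strip_bos_token=True, add_eos_token=True) separately and concatenates the pieces, making the boundary exact by construction. The sketch below mirrors that flow outside of axolotl; build_masked_example and the huggyllama/llama-7b checkpoint are illustrative assumptions, not part of this patch or of axolotl's API.

# Standalone sketch of the separate-tokenization scheme used in the patch,
# assuming a Llama tokenizer whose defaults add BOS but not EOS. All names
# here are hypothetical stand-ins for InstructionPromptTokenizingStrategy.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # assumed checkpoint


def build_masked_example(user_prompt, response, train_on_inputs=False):
    # User half: BOS + prompt tokens, no EOS, so the response continues the sequence.
    prompt_ids = tokenizer(user_prompt)["input_ids"]
    # Mask the prompt from the loss unless we explicitly train on inputs.
    labels = list(prompt_ids) if train_on_inputs else [-100] * len(prompt_ids)
    # Response half: no BOS (add_special_tokens=False), EOS appended manually,
    # mirroring strip_bos_token=True / add_eos_token=True in the patch.
    response_ids = tokenizer(response, add_special_tokens=False)["input_ids"]
    response_ids.append(tokenizer.eos_token_id)
    return {
        "input_ids": prompt_ids + response_ids,
        "attention_mask": [1] * (len(prompt_ids) + len(response_ids)),
        "labels": labels + response_ids,  # loss flows only through the response
    }


# Mirrors test_no_sys_prompt above: 3186 is "world" in the Llama vocabulary;
# it must carry a real label while the label just before it stays masked.
example = build_masked_example("hello cruel. lorem ipsum dolor sit amet. ", "world!")
world_idx = example["input_ids"].index(3186)
assert example["labels"][world_idx] == 3186
assert example["labels"][world_idx - 1] == -100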