diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py
index 1183c1e8e..6161d7e37 100644
--- a/src/axolotl/prompt_strategies/alpaca_chat.py
+++ b/src/axolotl/prompt_strategies/alpaca_chat.py
@@ -40,6 +40,18 @@ class AlpacaChatPrompter(AlpacaPrompter):
         self.match_prompt_style()
 
 
+class NoSystemPrompter(AlpacaPrompter):
+    """
+    Null Prompter with no system prompts
+    """
+
+    prompt_input = "{instruction} {input} "
+    prompt_no_input = "{instruction} "
+
+    def __init__(self):  # pylint: disable=super-init-not-called
+        pass
+
+
 class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     """
     Tokenizing strategy for AlpacaQA
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 8b3c88fee..6408620d7 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -96,25 +96,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
             input,  # pylint: disable=redefined-builtin
             response,
         ) = self.parse_instruction_fields(prompt)
-        full_prompt = self._build_full_prompt(instruction, input, response)
-        tokenized_full_prompt = self._tokenize(full_prompt)
-        if not self.train_on_inputs:
-            user_prompt = next(
-                iter(
-                    self.prompter.build_prompt(
-                        instruction,
-                        input,
-                    )
+        user_prompt = next(
+            iter(
+                self.prompter.build_prompt(
+                    instruction,
+                    input,
                 )
             )
-            tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
-            user_prompt_len = len(tokenized_user_prompt["input_ids"])
+        )
+        tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
+        if not self.train_on_inputs:
+            user_prompt_len = len(tokenized_prompt["input_ids"])
             # TODO this could be sped up using numpy array slicing
-            tokenized_full_prompt["labels"] = [
-                -100
-            ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
+            tokenized_prompt["labels"] = [-100] * user_prompt_len
+        tokenized_res_prompt = self._tokenize(
+            response, strip_bos_token=True, add_eos_token=True
+        )
+        tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
+        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
+        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
 
-        return tokenized_full_prompt
+        return tokenized_prompt
 
     def _build_full_prompt(
         self, instruction, input, response  # pylint: disable=redefined-builtin
diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py
index 89209e84f..abc746bbf 100644
--- a/tests/test_prompt_tokenizers.py
+++ b/tests/test_prompt_tokenizers.py
@@ -6,8 +6,12 @@
 from pathlib import Path
 
 from transformers import AutoTokenizer
 
-from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
-from axolotl.prompters import ShareGPTPrompter
+from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
+from axolotl.prompt_tokenizers import (
+    AlpacaPromptTokenizingStrategy,
+    ShareGPTPromptTokenizingStrategy,
+)
+from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter
 
 logging.basicConfig(level="INFO")
@@ -29,7 +33,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
         )
 
     def test_sharegpt_integration(self):
-        print(Path(__file__).parent)
         with open(
             Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
         ) as fin:
@@ -53,6 +56,43 @@
             self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
             self.assertEqual(example[fields], tokenized_conversation[fields])
 
+    def test_completion(self):
+        """
+        tests the interface between the user and assistant parts
+        """
+        prompter = NoSystemPrompter()
+        strat = AlpacaPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        sample = {
+            "instruction": "hello cruel. lorem ipsum dolor sit amet.",
+            "output": "world!",
+        }
+        example = strat.tokenize_prompt(sample)
+        world_idx = example["input_ids"].index(3186)
+        assert example["labels"][world_idx] == 3186
+        assert example["labels"][world_idx - 1] == -100
+
+    def test_alpaca(self):
+        """
+        tests the interface between the user and assistant parts
+        """
+        prompter = AlpacaPrompter()
+        strat = AlpacaPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
+        example = strat.tokenize_prompt(sample)
+        world_idx = example["input_ids"].index(6324)
+        assert example["labels"][world_idx] == 6324
+        assert example["labels"][world_idx - 1] == -100
+
 
 if __name__ == "__main__":
     unittest.main()