From 9b8585dc70ceca667f92ea46f159b3d711055008 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 31 May 2023 11:38:52 -0400 Subject: [PATCH 1/2] fix packing so that concatenated sequences reset the attention --- src/axolotl/datasets.py | 5 +++ tests/fixtures/alpaca/alpaca.json | 12 ++++++ tests/test_packed_dataset.py | 64 +++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+) create mode 100644 tests/fixtures/alpaca/alpaca.json create mode 100644 tests/test_packed_dataset.py diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index fb5e15656..d6367ce7c 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -127,6 +127,11 @@ class ConstantLengthDataset(IterableDataset): input_ids = example["input_ids"] attention_mask = example["attention_mask"] labels = example["labels"] + if ( + buffer["input_ids"] + and input_ids[0] == self.tokenizer.bos_token_id + ): + attention_mask[0] = 0 if add_concat_token: input_ids.append(self.concat_token_id) diff --git a/tests/fixtures/alpaca/alpaca.json b/tests/fixtures/alpaca/alpaca.json new file mode 100644 index 000000000..912643d17 --- /dev/null +++ b/tests/fixtures/alpaca/alpaca.json @@ -0,0 +1,12 @@ +[ + { + "instruction": "You will be given a series of words. Output these words in reverse order, with each word on its own line.", + "input": "Words: ['Hello', 'world'].", + "output": "['world', 'Hello']" + }, + { + "instruction": "In this task, you're given a short description of an event. Your job is to order the steps involved in the event from first to last. Note that there may be multiple correct answers for each event.", + "input": "Description: A man walks into a bar and orders a drink. He pays for his drink and leaves the bar.", + "output": "1. The man walks into the bar.\n2. He orders a drink.\n3. He pays for his drink.\n4. He leaves the bar." 
+ } +] diff --git a/tests/test_packed_dataset.py b/tests/test_packed_dataset.py new file mode 100644 index 000000000..ced0360fe --- /dev/null +++ b/tests/test_packed_dataset.py @@ -0,0 +1,64 @@ +"""Module for testing dataset sequence packing""" + +import unittest +from pathlib import Path + +from datasets import Dataset, load_dataset +from transformers import AutoTokenizer + +from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset +from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy +from axolotl.prompters import AlpacaPrompter + + +class TestPacking(unittest.TestCase): + """ + Test class for packing dataset sequences + """ + + def setUp(self) -> None: + self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") + self.tokenizer.add_special_tokens( + { + "bos_token": "<s>", + "eos_token": "</s>", + "unk_token": "<unk>", + } + ) + + def test_resets_attention(self): + prompter = AlpacaPrompter("chat") + strat = AlpacaPromptTokenizingStrategy( + prompter, + self.tokenizer, + False, + 2048, + ) + dateset = load_dataset( + "json", + data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"), + )["train"] + dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset))) + + constant_len_dataset = ConstantLengthDataset( + self.tokenizer, + [dataset], + seq_length=2048, + ) + packed_dataset = Dataset.from_list(list(constant_len_dataset)) + example = packed_dataset[0] + next_bos_index = ( + example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1 + ) # add one since we sliced + + # first example doesn't have mask reset + assert example["input_ids"][0] == self.tokenizer.bos_token_id + assert example["attention_mask"][0] == 1 + + # but subsequent one does + assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id + assert example["attention_mask"][next_bos_index] == 0 + + +if __name__ == "__main__": + unittest.main() From 0136f510f2dcbf039a20e1fdd6c5c256016f6390 Mon Sep 17 00:00:00 2001 From: 
Wing Lian Date: Wed, 31 May 2023 12:05:43 -0400 Subject: [PATCH 2/2] don't worry about duplicate code here --- tests/test_packed_dataset.py | 1 + tests/test_prompt_tokenizers.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/test_packed_dataset.py b/tests/test_packed_dataset.py index ced0360fe..1f19d0ecc 100644 --- a/tests/test_packed_dataset.py +++ b/tests/test_packed_dataset.py @@ -17,6 +17,7 @@ class TestPacking(unittest.TestCase): """ def setUp(self) -> None: + # pylint: disable=duplicate-code self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer.add_special_tokens( { diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index fa85fe5f6..89209e84f 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -18,6 +18,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase): """ def setUp(self) -> None: + # pylint: disable=duplicate-code self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer.add_special_tokens( {