From 61f44f311eb0163d001fae2d8ac2042c951c1847 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 13 Jun 2023 21:26:13 -0400
Subject: [PATCH] fix packing for tokenizers that don't use a bos_token when
 the bos and eos tokens are the same

---
 src/axolotl/datasets.py           |  8 +++--
 src/axolotl/utils/tokenization.py |  2 ++
 tests/test_packed_dataset.py      | 60 ++++++++++++++++++++++++++++++-
 3 files changed, 67 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index 40c58bc9c..f2fd1c5c9 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -132,8 +132,12 @@ class ConstantLengthDataset(IterableDataset):
                 attention_mask = example["attention_mask"]
                 labels = example["labels"]

                 if (
-                    buffer["input_ids"]
-                    and input_ids[0] == self.tokenizer.bos_token_id
+                    (
+                        buffer["input_ids"]
+                        and input_ids[0] == self.tokenizer.bos_token_id
+                    )
+                    or self.tokenizer.bos_token_id
+                    == self.tokenizer.eos_token_id
                 ):
                     attention_mask[0] = 0

diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py
index 1c535eb1b..cd49d7218 100644
--- a/src/axolotl/utils/tokenization.py
+++ b/src/axolotl/utils/tokenization.py
@@ -34,3 +34,5 @@ def check_example_labels(example, tokenizer):

     logging.info(" ".join(colored_tokens))
     logging.info("\n\n\n")
+
+    print(" ".join(colored_tokens))

diff --git a/tests/test_packed_dataset.py b/tests/test_packed_dataset.py
index 1f19d0ecc..65bb2eb60 100644
--- a/tests/test_packed_dataset.py
+++ b/tests/test_packed_dataset.py
@@ -11,7 +11,62 @@ from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
 from axolotl.prompters import AlpacaPrompter


-class TestPacking(unittest.TestCase):
+class TestGpt2Packing(unittest.TestCase):
+    """
+    Test class for packing dataset sequences with a GPT-2 style tokenizer where bos == eos
+    """
+
+    def setUp(self) -> None:
+        # pylint: disable=duplicate-code
+        self.tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderplus")
+        self.tokenizer.add_special_tokens(
+            {
+                "bos_token": "<|endoftext|>",
+                "eos_token": "<|endoftext|>",
+                "unk_token": "<|endoftext|>",
+            }
+        )
+        self.tokenizer.bos_token_id = 0
+        self.tokenizer.eos_token_id = 0
+        self.tokenizer.unk_token_id = 0
+
+    def test_resets_attention(self):
+        prompter = AlpacaPrompter("chat")
+        strat = AlpacaPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        dataset = load_dataset(
+            "json",
+            data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
+        )["train"]
+        dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dataset)))
+
+        constant_len_dataset = ConstantLengthDataset(
+            self.tokenizer,
+            [dataset],
+            seq_length=2048,
+        )
+        packed_dataset = Dataset.from_list(list(constant_len_dataset))
+
+        example = packed_dataset[0]
+        from axolotl.utils.tokenization import check_example_labels
+
+        check_example_labels(example, self.tokenizer)
+        # tokenizers where the eos and bos tokens are the same emit no distinct bos token
+        next_eos_index = (
+            example["input_ids"][1:].index(self.tokenizer.eos_token_id) + 1
+        )  # add one since we sliced
+
+        print(example["input_ids"][next_eos_index + 1])
+
+        assert example["input_ids"][next_eos_index] == self.tokenizer.eos_token_id
+        assert example["attention_mask"][next_eos_index + 1] == 0
+
+
+class TestLlamaPacking(unittest.TestCase):
     """
     Test class for packing dataset sequences
     """
@@ -48,6 +103,9 @@ class TestPacking(unittest.TestCase):
         )
         packed_dataset = Dataset.from_list(list(constant_len_dataset))
         example = packed_dataset[0]
+        from axolotl.utils.tokenization import check_example_labels
+
+        check_example_labels(example, self.tokenizer)
         next_bos_index = (
             example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
         )  # add one since we sliced
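
Note on the change in src/axolotl/datasets.py: when examples are packed back
to back, the first attention_mask entry of an appended example is zeroed so
tokens cannot attend across the packing boundary. Previously this only
happened when the example began with a distinct bos token; tokenizers that
reuse a single id for bos and eos (e.g. starcoderplus, where both map to
<|endoftext|>, id 0) never matched that check, so the mask was never reset.
Below is a minimal standalone sketch of the rule the patch adds. The function
name and signature are illustrative, not part of the axolotl API:

    from typing import List

    def mask_packing_boundary(
        buffer_input_ids: List[int],
        input_ids: List[int],
        attention_mask: List[int],
        bos_token_id: int,
        eos_token_id: int,
    ) -> List[int]:
        """Zero the first mask entry at a packed-sequence boundary."""
        starts_with_bos = bool(buffer_input_ids) and input_ids[0] == bos_token_id
        # tokenizers with a shared bos/eos id never start with a distinct bos,
        # so the boundary mask must be reset unconditionally
        shared_bos_eos = bos_token_id == eos_token_id
        if starts_with_bos or shared_bos_eos:
            attention_mask = [0] + attention_mask[1:]
        return attention_mask

    # shared bos/eos id (0), as in the starcoderplus setup above:
    assert mask_packing_boundary([], [0, 5, 6, 0], [1, 1, 1, 1], 0, 0) == [0, 1, 1, 1]
    # distinct bos (1) and eos (2), as in llama: only reset mid-pack
    assert mask_packing_boundary([1, 9, 2], [1, 7, 2], [1, 1, 1], 1, 2) == [0, 1, 1]

This mirrors what TestGpt2Packing.test_resets_attention asserts: in a packed
example, the token immediately after an eos must carry attention_mask == 0.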