Merge pull request #131 from OpenAccess-AI-Collective/fix-packing-mask

fix packing so that concatenated sequences reset the attention
This commit is contained in:
Wing Lian
2023-05-31 13:03:37 -04:00
committed by GitHub
4 changed files with 83 additions and 0 deletions

View File

@@ -127,6 +127,11 @@ class ConstantLengthDataset(IterableDataset):
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if (
buffer["input_ids"]
and input_ids[0] == self.tokenizer.bos_token_id
):
attention_mask[0] = 0
if add_concat_token:
input_ids.append(self.concat_token_id)

12
tests/fixtures/alpaca/alpaca.json vendored Normal file
View File

@@ -0,0 +1,12 @@
[
{
"instruction": "You will be given a series of words. Output these words in reverse order, with each word on its own line.",
"input": "Words: ['Hello', 'world'].",
"output": "['world', 'Hello']"
},
{
"instruction": "In this task, you're given a short description of an event. Your job is to order the steps involved in the event from first to last. Note that there may be multiple correct answers for each event.",
"input": "Description: A man walks into a bar and orders a drink. He pays for his drink and leaves the bar.",
"output": "1. The man walks into the bar.\n2. He orders a drink.\n3. He pays for his drink.\n4. He leaves the bar."
}
]

View File

@@ -0,0 +1,65 @@
"""Module for testing dataset sequence packing"""
import unittest
from pathlib import Path
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
class TestPacking(unittest.TestCase):
"""
Test class for packing dataset sequences
"""
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
)
def test_resets_attention(self):
prompter = AlpacaPrompter("chat")
strat = AlpacaPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
dateset = load_dataset(
"json",
data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
)["train"]
dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset)))
constant_len_dataset = ConstantLengthDataset(
self.tokenizer,
[dataset],
seq_length=2048,
)
packed_dataset = Dataset.from_list(list(constant_len_dataset))
example = packed_dataset[0]
next_bos_index = (
example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
) # add one since we sliced
# first example doesn't have mask reset
assert example["input_ids"][0] == self.tokenizer.bos_token_id
assert example["attention_mask"][0] == 1
# but subsequent one does
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
assert example["attention_mask"][next_bos_index] == 0
if __name__ == "__main__":
unittest.main()

View File

@@ -18,6 +18,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
"""
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{