Compare commits

jagged-res ... no-bos-tok

2 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 05d19d2037 |  |
|  | 61f44f311e |  |
@@ -132,8 +132,12 @@ class ConstantLengthDataset(IterableDataset):
         attention_mask = example["attention_mask"]
         labels = example["labels"]
         if (
-            buffer["input_ids"]
-            and input_ids[0] == self.tokenizer.bos_token_id
+            (
+                buffer["input_ids"]
+                and input_ids[0] == self.tokenizer.bos_token_id
+            )
+            or self.tokenizer.bos_token_id
+            == self.tokenizer.eos_token_id
         ):
             attention_mask[0] = 0
 
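In words: the first token of a newly packed example gets its attention mask zeroed either when it is a bos token gluing the example onto a non-empty buffer, or unconditionally when the tokenizer reuses one id for both bos and eos (as GPT-2 does with <|endoftext|>), since that token can only mark a packing boundary. A minimal sketch of the rule, where `tokenizer`, `buffer`, and `input_ids` mirror the names above but the standalone helper itself is hypothetical:

# Hypothetical standalone restatement of the condition in the hunk above;
# `tokenizer`, `buffer`, and `input_ids` mirror the names used there.
def should_mask_first_token(tokenizer, buffer, input_ids):
    # Case 1: the example starts with bos and is being appended to a
    # non-empty buffer, so the bos only separates packed sequences.
    follows_packed_sequence = (
        bool(buffer["input_ids"]) and input_ids[0] == tokenizer.bos_token_id
    )
    # Case 2: bos and eos share one id (e.g. GPT-2's <|endoftext|>), so
    # there is no distinct bos token and the boundary is always masked.
    no_distinct_bos = tokenizer.bos_token_id == tokenizer.eos_token_id
    return follows_packed_sequence or no_distinct_bos

Inside the packing loop this would be used as: if the helper returns True, set attention_mask[0] = 0, exactly as the hunk does inline.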
@@ -34,3 +34,5 @@ def check_example_labels(example, tokenizer):
 
     logging.info(" ".join(colored_tokens))
     logging.info("\n\n\n")
+
+    print(" ".join(colored_tokens))
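For reference, `colored_tokens` is a per-token visualization of which positions carry labels; it is built along these lines (an illustrative sketch, not the actual axolotl implementation):

from termcolor import colored

def color_tokens(example, tokenizer):
    # Illustrative sketch: green = position is trained on,
    # red = label masked out with -100.
    return [
        colored(tokenizer.decode([token_id]), "green" if label != -100 else "red")
        for token_id, label in zip(example["input_ids"], example["labels"])
    ]

The added print() writes the ANSI color codes straight to stdout, keeping the colors visible even when the logging module's output is redirected or reformatted.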
@@ -11,7 +11,57 @@ from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
 from axolotl.prompters import AlpacaPrompter
 
 
-class TestPacking(unittest.TestCase):
+class TestGpt2Packing(unittest.TestCase):
     """
     Test class for packing dataset sequences
     """
+
+    def setUp(self) -> None:
+        # pylint: disable=duplicate-code
+        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        self.tokenizer.add_special_tokens(
+            {
+                "bos_token": "<|endoftext|>",
+                "eos_token": "<|endoftext|>",
+                "unk_token": "<|endoftext|>",
+            }
+        )
+        self.tokenizer.bos_token_id = 0
+        self.tokenizer.eos_token_id = 0
+        self.tokenizer.unk_token_id = 0
+
+    def test_resets_attention(self):
+        prompter = AlpacaPrompter("chat")
+        strat = AlpacaPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        dateset = load_dataset(
+            "json",
+            data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
+        )["train"]
+        dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset)))
+
+        constant_len_dataset = ConstantLengthDataset(
+            self.tokenizer,
+            [dataset],
+            seq_length=2048,
+        )
+        packed_dataset = Dataset.from_list(list(constant_len_dataset))
+
+        example = packed_dataset[0]
+        # tokenizers where eos and bos tokens are the same don't have a bos token
+        next_eos_index = (
+            example["input_ids"][1:].index(self.tokenizer.eos_token_id) + 1
+        )  # add one since we sliced
+
+        assert example["input_ids"][next_eos_index] == self.tokenizer.eos_token_id
+        assert example["attention_mask"][next_eos_index + 1] == 0
+
+
+class TestLlamaPacking(unittest.TestCase):
+    """
+    Test class for packing dataset sequences
+    """
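The new test pins bos, eos, and unk to id 0 for determinism, but the stock GPT-2 tokenizer already has the property the datasets change guards against; a quick check with the standard transformers API:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 reuses a single <|endoftext|> token (id 50256) for bos and eos,
# so the new `bos_token_id == eos_token_id` branch applies to it.
assert tokenizer.bos_token_id == tokenizer.eos_token_id == 50256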