fix attetion mask with packing
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled

This commit is contained in:
Wing Lian
2023-07-15 10:38:01 -04:00
parent 33814cc94e
commit 8028652b8f
2 changed files with 6 additions and 8 deletions

View File

@@ -27,7 +27,7 @@ class TestPacking(unittest.TestCase):
}
)
def test_resets_attention(self):
def test_increments_attention(self):
prompter = AlpacaPrompter("chat")
strat = AlpacaPromptTokenizingStrategy(
prompter,
@@ -58,7 +58,7 @@ class TestPacking(unittest.TestCase):
# but subsequent one does
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
assert example["attention_mask"][next_bos_index] == 0
assert example["attention_mask"][next_bos_index] == 2
if __name__ == "__main__":