fix attetion mask with packing

This commit is contained in:
Wing Lian
2023-07-15 10:38:01 -04:00
parent 176b888a63
commit 36b0e30a9d
2 changed files with 6 additions and 8 deletions

View File

@@ -27,7 +27,7 @@ class TestPacking(unittest.TestCase):
}
)
def test_resets_attention(self):
def test_increments_attention(self):
prompter = AlpacaPrompter("chat")
strat = AlpacaPromptTokenizingStrategy(
prompter,
@@ -58,7 +58,7 @@ class TestPacking(unittest.TestCase):
# but subsequent one does
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
assert example["attention_mask"][next_bos_index] == 0
assert example["attention_mask"][next_bos_index] == 2
if __name__ == "__main__":