fix packing so that concatenated sequences reset the attention mask

This commit is contained in:
Wing Lian
2023-05-31 11:38:52 -04:00
parent 8eb5811d4e
commit 9b8585dc70
3 changed files with 81 additions and 0 deletions

View File

@@ -127,6 +127,11 @@ class ConstantLengthDataset(IterableDataset):
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if (
buffer["input_ids"]
and input_ids[0] == self.tokenizer.bos_token_id
):
attention_mask[0] = 0
if add_concat_token:
input_ids.append(self.concat_token_id)