Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
8028652b8f fix attetion mask with packing
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-07-15 10:38:01 -04:00
3 changed files with 11 additions and 14 deletions

View File

@@ -79,11 +79,13 @@ class ConstantLengthDataset(IterableDataset):
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
buffer_len = 0
for dataset in self.datasets:
idx = 0
iterator = iter(dataset)
more_examples = True
while more_examples:
try:
example = next(iterator)
idx += 1
except StopIteration:
more_examples = False
example = None
@@ -124,6 +126,7 @@ class ConstantLengthDataset(IterableDataset):
"labels": [],
}
buffer_len = 0
idx = 1
if example:
# FIXME
@@ -132,11 +135,6 @@ class ConstantLengthDataset(IterableDataset):
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if (
buffer["input_ids"]
and input_ids[0] == self.tokenizer.bos_token_id
):
attention_mask[0] = 0
if add_concat_token:
input_ids.append(self.concat_token_id)
@@ -147,7 +145,7 @@ class ConstantLengthDataset(IterableDataset):
input_ids, dtype=self.tokens_dtype
)
attention_mask_with_concat = torch.tensor(
attention_mask, dtype=self.tokens_dtype
[idx * m for m in attention_mask], dtype=torch.int16
)
labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype

View File

@@ -405,14 +405,13 @@ def load_prepare_datasets(
private=True,
)
else:
# dataset_train = load_tokenized_prepared_datasets(
dataset = load_tokenized_prepared_datasets(
dataset_train = load_tokenized_prepared_datasets(
"train", tokenizer, cfg, default_dataset_prepared_path
)
# dataset_test = load_tokenized_prepared_datasets(
# "test", tokenizer, cfg, default_dataset_prepared_path
# )
# dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
dataset_test = load_tokenized_prepared_datasets(
"test", tokenizer, cfg, default_dataset_prepared_path
)
dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
logging.info(

View File

@@ -27,7 +27,7 @@ class TestPacking(unittest.TestCase):
}
)
def test_resets_attention(self):
def test_increments_attention(self):
prompter = AlpacaPrompter("chat")
strat = AlpacaPromptTokenizingStrategy(
prompter,
@@ -58,7 +58,7 @@ class TestPacking(unittest.TestCase):
# but subsequent one does
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
assert example["attention_mask"][next_bos_index] == 0
assert example["attention_mask"][next_bos_index] == 2
if __name__ == "__main__":