diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index 5593a8dd3..8a23ee5a0 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -79,11 +79,13 @@ class ConstantLengthDataset(IterableDataset):
         buffer = {"input_ids": [], "attention_mask": [], "labels": []}
         buffer_len = 0
         for dataset in self.datasets:
+            idx = 0
             iterator = iter(dataset)
             more_examples = True
             while more_examples:
                 try:
                     example = next(iterator)
+                    idx += 1
                 except StopIteration:
                     more_examples = False
                     example = None
@@ -124,6 +126,7 @@
                             "labels": [],
                         }
                         buffer_len = 0
+                        idx = 1
 
                 if example:
                     # FIXME
@@ -132,11 +135,6 @@
                     input_ids = example["input_ids"]
                     attention_mask = example["attention_mask"]
                     labels = example["labels"]
-                    if (
-                        buffer["input_ids"]
-                        and input_ids[0] == self.tokenizer.bos_token_id
-                    ):
-                        attention_mask[0] = 0
 
                     if add_concat_token:
                         input_ids.append(self.concat_token_id)
@@ -147,7 +145,7 @@
                     input_ids, dtype=self.tokens_dtype
                 )
                 attention_mask_with_concat = torch.tensor(
-                    attention_mask, dtype=self.tokens_dtype
+                    [idx * m for m in attention_mask], dtype=torch.int16
                 )
                 labels_with_concat = torch.tensor(
                     labels, dtype=self.tokens_dtype
diff --git a/tests/test_packed_dataset.py b/tests/test_packed_dataset.py
index 1f19d0ecc..a0f476a42 100644
--- a/tests/test_packed_dataset.py
+++ b/tests/test_packed_dataset.py
@@ -27,7 +27,7 @@ class TestPacking(unittest.TestCase):
             }
         )
 
-    def test_resets_attention(self):
+    def test_increments_attention(self):
         prompter = AlpacaPrompter("chat")
         strat = AlpacaPromptTokenizingStrategy(
             prompter,
@@ -58,7 +58,7 @@ class TestPacking(unittest.TestCase):
 
         # but subsequent one does
         assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
-        assert example["attention_mask"][next_bos_index] == 0
+        assert example["attention_mask"][next_bos_index] == 2
 

 if __name__ == "__main__":