Compare commits

..

1 Commit

Author SHA1 Message Date
Wing Lian
f6721baf10 tweak to make it work when we have no explicit test split
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-07-11 22:40:21 -04:00
3 changed files with 14 additions and 11 deletions

View File

@@ -79,13 +79,11 @@ class ConstantLengthDataset(IterableDataset):
buffer = {"input_ids": [], "attention_mask": [], "labels": []} buffer = {"input_ids": [], "attention_mask": [], "labels": []}
buffer_len = 0 buffer_len = 0
for dataset in self.datasets: for dataset in self.datasets:
idx = 0
iterator = iter(dataset) iterator = iter(dataset)
more_examples = True more_examples = True
while more_examples: while more_examples:
try: try:
example = next(iterator) example = next(iterator)
idx += 1
except StopIteration: except StopIteration:
more_examples = False more_examples = False
example = None example = None
@@ -126,7 +124,6 @@ class ConstantLengthDataset(IterableDataset):
"labels": [], "labels": [],
} }
buffer_len = 0 buffer_len = 0
idx = 1
if example: if example:
# FIXME # FIXME
@@ -135,6 +132,11 @@ class ConstantLengthDataset(IterableDataset):
input_ids = example["input_ids"] input_ids = example["input_ids"]
attention_mask = example["attention_mask"] attention_mask = example["attention_mask"]
labels = example["labels"] labels = example["labels"]
if (
buffer["input_ids"]
and input_ids[0] == self.tokenizer.bos_token_id
):
attention_mask[0] = 0
if add_concat_token: if add_concat_token:
input_ids.append(self.concat_token_id) input_ids.append(self.concat_token_id)
@@ -145,7 +147,7 @@ class ConstantLengthDataset(IterableDataset):
input_ids, dtype=self.tokens_dtype input_ids, dtype=self.tokens_dtype
) )
attention_mask_with_concat = torch.tensor( attention_mask_with_concat = torch.tensor(
[idx * m for m in attention_mask], dtype=torch.int16 attention_mask, dtype=self.tokens_dtype
) )
labels_with_concat = torch.tensor( labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype labels, dtype=self.tokens_dtype

View File

@@ -405,13 +405,14 @@ def load_prepare_datasets(
private=True, private=True,
) )
else: else:
dataset_train = load_tokenized_prepared_datasets( # dataset_train = load_tokenized_prepared_datasets(
dataset = load_tokenized_prepared_datasets(
"train", tokenizer, cfg, default_dataset_prepared_path "train", tokenizer, cfg, default_dataset_prepared_path
) )
dataset_test = load_tokenized_prepared_datasets( # dataset_test = load_tokenized_prepared_datasets(
"test", tokenizer, cfg, default_dataset_prepared_path # "test", tokenizer, cfg, default_dataset_prepared_path
) # )
dataset = DatasetDict({"train": dataset_train, "test": dataset_test}) # dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
logging.info( logging.info(

View File

@@ -27,7 +27,7 @@ class TestPacking(unittest.TestCase):
} }
) )
def test_increments_attention(self): def test_resets_attention(self):
prompter = AlpacaPrompter("chat") prompter = AlpacaPrompter("chat")
strat = AlpacaPromptTokenizingStrategy( strat = AlpacaPromptTokenizingStrategy(
prompter, prompter,
@@ -58,7 +58,7 @@ class TestPacking(unittest.TestCase):
# but subsequent one does # but subsequent one does
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
assert example["attention_mask"][next_bos_index] == 2 assert example["attention_mask"][next_bos_index] == 0
if __name__ == "__main__": if __name__ == "__main__":