Compare commits
1 Commits
openorca
...
openorca-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8028652b8f |
@@ -79,11 +79,13 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
|
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||||
buffer_len = 0
|
buffer_len = 0
|
||||||
for dataset in self.datasets:
|
for dataset in self.datasets:
|
||||||
|
idx = 0
|
||||||
iterator = iter(dataset)
|
iterator = iter(dataset)
|
||||||
more_examples = True
|
more_examples = True
|
||||||
while more_examples:
|
while more_examples:
|
||||||
try:
|
try:
|
||||||
example = next(iterator)
|
example = next(iterator)
|
||||||
|
idx += 1
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
more_examples = False
|
more_examples = False
|
||||||
example = None
|
example = None
|
||||||
@@ -124,6 +126,7 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
"labels": [],
|
"labels": [],
|
||||||
}
|
}
|
||||||
buffer_len = 0
|
buffer_len = 0
|
||||||
|
idx = 1
|
||||||
|
|
||||||
if example:
|
if example:
|
||||||
# FIXME
|
# FIXME
|
||||||
@@ -132,11 +135,6 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
input_ids = example["input_ids"]
|
input_ids = example["input_ids"]
|
||||||
attention_mask = example["attention_mask"]
|
attention_mask = example["attention_mask"]
|
||||||
labels = example["labels"]
|
labels = example["labels"]
|
||||||
if (
|
|
||||||
buffer["input_ids"]
|
|
||||||
and input_ids[0] == self.tokenizer.bos_token_id
|
|
||||||
):
|
|
||||||
attention_mask[0] = 0
|
|
||||||
|
|
||||||
if add_concat_token:
|
if add_concat_token:
|
||||||
input_ids.append(self.concat_token_id)
|
input_ids.append(self.concat_token_id)
|
||||||
@@ -147,7 +145,7 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
input_ids, dtype=self.tokens_dtype
|
input_ids, dtype=self.tokens_dtype
|
||||||
)
|
)
|
||||||
attention_mask_with_concat = torch.tensor(
|
attention_mask_with_concat = torch.tensor(
|
||||||
attention_mask, dtype=self.tokens_dtype
|
[idx * m for m in attention_mask], dtype=torch.int16
|
||||||
)
|
)
|
||||||
labels_with_concat = torch.tensor(
|
labels_with_concat = torch.tensor(
|
||||||
labels, dtype=self.tokens_dtype
|
labels, dtype=self.tokens_dtype
|
||||||
|
|||||||
@@ -405,14 +405,13 @@ def load_prepare_datasets(
|
|||||||
private=True,
|
private=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# dataset_train = load_tokenized_prepared_datasets(
|
dataset_train = load_tokenized_prepared_datasets(
|
||||||
dataset = load_tokenized_prepared_datasets(
|
|
||||||
"train", tokenizer, cfg, default_dataset_prepared_path
|
"train", tokenizer, cfg, default_dataset_prepared_path
|
||||||
)
|
)
|
||||||
# dataset_test = load_tokenized_prepared_datasets(
|
dataset_test = load_tokenized_prepared_datasets(
|
||||||
# "test", tokenizer, cfg, default_dataset_prepared_path
|
"test", tokenizer, cfg, default_dataset_prepared_path
|
||||||
# )
|
)
|
||||||
# dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
|
dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
|
||||||
|
|
||||||
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
||||||
logging.info(
|
logging.info(
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class TestPacking(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_resets_attention(self):
|
def test_increments_attention(self):
|
||||||
prompter = AlpacaPrompter("chat")
|
prompter = AlpacaPrompter("chat")
|
||||||
strat = AlpacaPromptTokenizingStrategy(
|
strat = AlpacaPromptTokenizingStrategy(
|
||||||
prompter,
|
prompter,
|
||||||
@@ -58,7 +58,7 @@ class TestPacking(unittest.TestCase):
|
|||||||
|
|
||||||
# but subsequent one does
|
# but subsequent one does
|
||||||
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
|
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
|
||||||
assert example["attention_mask"][next_bos_index] == 0
|
assert example["attention_mask"][next_bos_index] == 2
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user