move filter to before saving so it doesn't happen every time, update runpod manual script

This commit is contained in:
Wing Lian
2023-05-13 21:51:41 -04:00
parent 84c7bc4b68
commit 0d28df0fd2
3 changed files with 16 additions and 16 deletions

View File

@@ -198,6 +198,18 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
)
dataset = Dataset.from_list([_ for _ in constant_len_dataset])
# filter out bad data
dataset = Dataset.from_list(
[
d
for d in dataset
if len(d["input_ids"]) < cfg.sequence_len
and len(d["input_ids"]) > 0
and len(d["input_ids"]) == len(d["attention_mask"])
and len(d["input_ids"]) == len(d["labels"])
]
)
if cfg.local_rank == 0:
logging.info(
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
@@ -208,18 +220,6 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
tokenizer, cfg, default_dataset_prepared_path
)
# filter out bad data
dataset = Dataset.from_list(
[
d
for d in dataset
if len(d["input_ids"]) < cfg.sequence_len
and len(d["input_ids"]) > 0
and len(d["input_ids"]) == len(d["attention_mask"])
and len(d["input_ids"]) == len(d["labels"])
]
)
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
logging.info(
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"