move filter to before saving so it doesn't happen everytime, update runpod manual script
This commit is contained in:
@@ -198,6 +198,18 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
||||
)
|
||||
dataset = Dataset.from_list([_ for _ in constant_len_dataset])
|
||||
|
||||
# filter out bad data
|
||||
dataset = Dataset.from_list(
|
||||
[
|
||||
d
|
||||
for d in dataset
|
||||
if len(d["input_ids"]) < cfg.sequence_len
|
||||
and len(d["input_ids"]) > 0
|
||||
and len(d["input_ids"]) == len(d["attention_mask"])
|
||||
and len(d["input_ids"]) == len(d["labels"])
|
||||
]
|
||||
)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
logging.info(
|
||||
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
|
||||
@@ -208,18 +220,6 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
||||
tokenizer, cfg, default_dataset_prepared_path
|
||||
)
|
||||
|
||||
# filter out bad data
|
||||
dataset = Dataset.from_list(
|
||||
[
|
||||
d
|
||||
for d in dataset
|
||||
if len(d["input_ids"]) < cfg.sequence_len
|
||||
and len(d["input_ids"]) > 0
|
||||
and len(d["input_ids"]) == len(d["attention_mask"])
|
||||
and len(d["input_ids"]) == len(d["labels"])
|
||||
]
|
||||
)
|
||||
|
||||
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
||||
logging.info(
|
||||
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
|
||||
|
||||
Reference in New Issue
Block a user