Fix data.py lint
@@ -1,3 +1,5 @@
+"""Module containing data utilities for Axolotl"""
+
 import logging
 from hashlib import md5
 from pathlib import Path
@@ -46,12 +48,12 @@ def load_tokenized_prepared_datasets(
         md5(
             (
                 str(cfg.sequence_len)
-                + "@"  # noqa: W503
-                + "|".join(  # noqa: W503
+                + "@"
+                + "|".join(
                     sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
                 )
-                + "|"  # noqa: W503
-                + tokenizer_name  # noqa: W503
+                + "|"
+                + tokenizer_name
             ).encode("utf-8")
         ).hexdigest()
     )
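
For context, the expression reflowed above builds an md5 fingerprint of everything that affects the prepared dataset (sequence length, each dataset's path/type/shards, and the tokenizer name) so the tokenized data can be cached on disk and reused. Below is a minimal standalone sketch of the same idea; the helper name dataset_fingerprint and its flat argument list are hypothetical, not part of the diff.

from hashlib import md5

def dataset_fingerprint(sequence_len, datasets, tokenizer_name):
    # Hypothetical helper mirroring the inline hash above: concatenate the
    # settings that affect tokenization and hash them, so the cached prepared
    # dataset is invalidated whenever any of them changes.
    key = (
        str(sequence_len)
        + "@"
        + "|".join(sorted(f"{d.path}:{d.type}:{d.shards}" for d in datasets))
        + "|"
        + tokenizer_name
    )
    return md5(key.encode("utf-8")).hexdigest()
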
@@ -81,6 +83,7 @@ def load_tokenized_prepared_datasets(
         logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
         logging.info("Loading raw datasets...")
         datasets = []
+        # pylint: disable=invalid-name
         for d in cfg.datasets:
             ds: Union[Dataset, DatasetDict] = None
             ds_from_hub = False
@@ -229,7 +232,7 @@ def load_tokenized_prepared_datasets(
 
         samples = []
         for d in datasets:
-            samples = samples + [i for i in d]
+            samples = samples + list(d)
         dataset = Dataset.from_list(samples).shuffle(seed=42)
         if cfg.local_rank == 0:
             logging.info(
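
The hunk above merges the per-source tokenized datasets by accumulating their rows into one Python list and shuffling the result; list(d) is behaviourally identical to the comprehension it replaces, since iterating a Dataset yields plain dicts. A small sketch of the pattern with toy in-memory rows (illustrative only, not from the diff):

from datasets import Dataset

# Two toy datasets standing in for the tokenized sources in cfg.datasets.
parts = [
    Dataset.from_list([{"input_ids": [1, 2]}, {"input_ids": [3]}]),
    Dataset.from_list([{"input_ids": [4, 5, 6]}]),
]

samples = []
for part in parts:
    samples = samples + list(part)  # each row is a dict, same as [i for i in part]

merged = Dataset.from_list(samples).shuffle(seed=42)
print(len(merged))  # 3 rows, order shuffled deterministically by the seed
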
@@ -265,14 +268,14 @@ def load_prepare_datasets(
         md5(
             (
                 str(cfg.sequence_len)
-                + "@"  # noqa: W503
-                + str(max_packed_sequence_len)  # noqa: W503
-                + seed  # noqa: W503
-                + "|".join(  # noqa: W503
+                + "@"
+                + str(max_packed_sequence_len)
+                + seed
+                + "|".join(
                     sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
                 )
-                + "|"  # noqa: W503
-                + tokenizer_name  # noqa: W503
+                + "|"
+                + tokenizer_name
             ).encode("utf-8")
         ).hexdigest()
     )
@@ -327,7 +330,7 @@ def load_prepare_datasets(
         logging.info(
             f"packing master dataset to len: {cfg.max_packed_sequence_len}"
         )
-        dataset = Dataset.from_list([_ for _ in constant_len_dataset])
+        dataset = Dataset.from_list(list(constant_len_dataset))
 
         # filter out bad data
         dataset = Dataset.from_list(
@@ -335,9 +338,9 @@ def load_prepare_datasets(
                 d
                 for d in dataset
                 if len(d["input_ids"]) < cfg.sequence_len
-                and len(d["input_ids"]) > 0  # noqa: W503
-                and len(d["input_ids"]) == len(d["attention_mask"])  # noqa: W503
-                and len(d["input_ids"]) == len(d["labels"])  # noqa: W503
+                and len(d["input_ids"]) > 0
+                and len(d["input_ids"]) == len(d["attention_mask"])
+                and len(d["input_ids"]) == len(d["labels"])
             ]
         )
 
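
The last hunk keeps the bad-data filter written as a list comprehension rebuilt into a new Dataset. The same predicate can also be phrased as a reusable function and applied with Dataset.filter; this is only a sketch of an alternative, assuming the dataset already has input_ids, attention_mask, and labels columns as in the surrounding code.

def is_well_formed(example, sequence_len):
    # Same conditions as the comprehension above: non-empty, shorter than the
    # configured sequence length, and with mask/label lengths matching input_ids.
    n = len(example["input_ids"])
    return (
        0 < n < sequence_len
        and n == len(example["attention_mask"])
        and n == len(example["labels"])
    )

# Hypothetical usage; `dataset` and `cfg` come from the surrounding code:
# dataset = dataset.filter(lambda ex: is_well_formed(ex, cfg.sequence_len))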