Fix data.py lint
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
"""Module containing data utilities for Axolotl"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -46,12 +48,12 @@ def load_tokenized_prepared_datasets(
|
|||||||
md5(
|
md5(
|
||||||
(
|
(
|
||||||
str(cfg.sequence_len)
|
str(cfg.sequence_len)
|
||||||
+ "@" # noqa: W503
|
+ "@"
|
||||||
+ "|".join( # noqa: W503
|
+ "|".join(
|
||||||
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
||||||
)
|
)
|
||||||
+ "|" # noqa: W503
|
+ "|"
|
||||||
+ tokenizer_name # noqa: W503
|
+ tokenizer_name
|
||||||
).encode("utf-8")
|
).encode("utf-8")
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
)
|
)
|
||||||
@@ -81,6 +83,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
||||||
logging.info("Loading raw datasets...")
|
logging.info("Loading raw datasets...")
|
||||||
datasets = []
|
datasets = []
|
||||||
|
# pylint: disable=invalid-name
|
||||||
for d in cfg.datasets:
|
for d in cfg.datasets:
|
||||||
ds: Union[Dataset, DatasetDict] = None
|
ds: Union[Dataset, DatasetDict] = None
|
||||||
ds_from_hub = False
|
ds_from_hub = False
|
||||||
@@ -229,7 +232,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
|
|
||||||
samples = []
|
samples = []
|
||||||
for d in datasets:
|
for d in datasets:
|
||||||
samples = samples + [i for i in d]
|
samples = samples + list(d)
|
||||||
dataset = Dataset.from_list(samples).shuffle(seed=42)
|
dataset = Dataset.from_list(samples).shuffle(seed=42)
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
logging.info(
|
logging.info(
|
||||||
@@ -265,14 +268,14 @@ def load_prepare_datasets(
|
|||||||
md5(
|
md5(
|
||||||
(
|
(
|
||||||
str(cfg.sequence_len)
|
str(cfg.sequence_len)
|
||||||
+ "@" # noqa: W503
|
+ "@"
|
||||||
+ str(max_packed_sequence_len) # noqa: W503
|
+ str(max_packed_sequence_len)
|
||||||
+ seed # noqa: W503
|
+ seed
|
||||||
+ "|".join( # noqa: W503
|
+ "|".join(
|
||||||
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
||||||
)
|
)
|
||||||
+ "|" # noqa: W503
|
+ "|"
|
||||||
+ tokenizer_name # noqa: W503
|
+ tokenizer_name
|
||||||
).encode("utf-8")
|
).encode("utf-8")
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
)
|
)
|
||||||
@@ -327,7 +330,7 @@ def load_prepare_datasets(
|
|||||||
logging.info(
|
logging.info(
|
||||||
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
|
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
|
||||||
)
|
)
|
||||||
dataset = Dataset.from_list([_ for _ in constant_len_dataset])
|
dataset = Dataset.from_list(list(constant_len_dataset))
|
||||||
|
|
||||||
# filter out bad data
|
# filter out bad data
|
||||||
dataset = Dataset.from_list(
|
dataset = Dataset.from_list(
|
||||||
@@ -335,9 +338,9 @@ def load_prepare_datasets(
|
|||||||
d
|
d
|
||||||
for d in dataset
|
for d in dataset
|
||||||
if len(d["input_ids"]) < cfg.sequence_len
|
if len(d["input_ids"]) < cfg.sequence_len
|
||||||
and len(d["input_ids"]) > 0 # noqa: W503
|
and len(d["input_ids"]) > 0
|
||||||
and len(d["input_ids"]) == len(d["attention_mask"]) # noqa: W503
|
and len(d["input_ids"]) == len(d["attention_mask"])
|
||||||
and len(d["input_ids"]) == len(d["labels"]) # noqa: W503
|
and len(d["input_ids"]) == len(d["labels"])
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user