Fix data.py lint

This commit is contained in:
NanoCode012
2023-05-29 09:21:08 +09:00
parent d57ba56746
commit cb7cd3429f

View File

@@ -1,3 +1,5 @@
"""Module containing data utilities for Axolotl"""
import logging
from hashlib import md5
from pathlib import Path
@@ -46,12 +48,12 @@ def load_tokenized_prepared_datasets(
md5(
(
str(cfg.sequence_len)
+ "@" # noqa: W503
+ "|".join( # noqa: W503
+ "@"
+ "|".join(
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
)
+ "|" # noqa: W503
+ tokenizer_name # noqa: W503
+ "|"
+ tokenizer_name
).encode("utf-8")
).hexdigest()
)
@@ -81,6 +83,7 @@ def load_tokenized_prepared_datasets(
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
logging.info("Loading raw datasets...")
datasets = []
# pylint: disable=invalid-name
for d in cfg.datasets:
ds: Union[Dataset, DatasetDict] = None
ds_from_hub = False
@@ -229,7 +232,7 @@ def load_tokenized_prepared_datasets(
samples = []
for d in datasets:
samples = samples + [i for i in d]
samples = samples + list(d)
dataset = Dataset.from_list(samples).shuffle(seed=42)
if cfg.local_rank == 0:
logging.info(
@@ -265,14 +268,14 @@ def load_prepare_datasets(
md5(
(
str(cfg.sequence_len)
+ "@" # noqa: W503
+ str(max_packed_sequence_len) # noqa: W503
+ seed # noqa: W503
+ "|".join( # noqa: W503
+ "@"
+ str(max_packed_sequence_len)
+ seed
+ "|".join(
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
)
+ "|" # noqa: W503
+ tokenizer_name # noqa: W503
+ "|"
+ tokenizer_name
).encode("utf-8")
).hexdigest()
)
@@ -327,7 +330,7 @@ def load_prepare_datasets(
logging.info(
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
)
dataset = Dataset.from_list([_ for _ in constant_len_dataset])
dataset = Dataset.from_list(list(constant_len_dataset))
# filter out bad data
dataset = Dataset.from_list(
@@ -335,9 +338,9 @@ def load_prepare_datasets(
d
for d in dataset
if len(d["input_ids"]) < cfg.sequence_len
and len(d["input_ids"]) > 0 # noqa: W503
and len(d["input_ids"]) == len(d["attention_mask"]) # noqa: W503
and len(d["input_ids"]) == len(d["labels"]) # noqa: W503
and len(d["input_ids"]) > 0
and len(d["input_ids"]) == len(d["attention_mask"])
and len(d["input_ids"]) == len(d["labels"])
]
)