Fix data.py lint

This commit is contained in:
NanoCode012
2023-05-29 09:21:08 +09:00
parent d57ba56746
commit cb7cd3429f

View File

@@ -1,3 +1,5 @@
"""Module containing data utilities for Axolotl"""
import logging
from hashlib import md5
from pathlib import Path
@@ -46,12 +48,12 @@ def load_tokenized_prepared_datasets(
  md5(
  (
  str(cfg.sequence_len)
- + "@" # noqa: W503
+ + "@"
- + "|".join( # noqa: W503
+ + "|".join(
  sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
  )
- + "|" # noqa: W503
+ + "|"
- + tokenizer_name # noqa: W503
+ + tokenizer_name
  ).encode("utf-8")
  ).hexdigest()
  )
@@ -81,6 +83,7 @@ def load_tokenized_prepared_datasets(
  logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
  logging.info("Loading raw datasets...")
  datasets = []
+ # pylint: disable=invalid-name
  for d in cfg.datasets:
  ds: Union[Dataset, DatasetDict] = None
  ds_from_hub = False
@@ -229,7 +232,7 @@ def load_tokenized_prepared_datasets(
  samples = []
  for d in datasets:
- samples = samples + [i for i in d]
+ samples = samples + list(d)
  dataset = Dataset.from_list(samples).shuffle(seed=42)
  if cfg.local_rank == 0:
  logging.info(
@@ -265,14 +268,14 @@ def load_prepare_datasets(
  md5(
  (
  str(cfg.sequence_len)
- + "@" # noqa: W503
+ + "@"
- + str(max_packed_sequence_len) # noqa: W503
+ + str(max_packed_sequence_len)
- + seed # noqa: W503
+ + seed
- + "|".join( # noqa: W503
+ + "|".join(
  sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
  )
- + "|" # noqa: W503
+ + "|"
- + tokenizer_name # noqa: W503
+ + tokenizer_name
  ).encode("utf-8")
  ).hexdigest()
  )
@@ -327,7 +330,7 @@ def load_prepare_datasets(
  logging.info(
  f"packing master dataset to len: {cfg.max_packed_sequence_len}"
  )
- dataset = Dataset.from_list([_ for _ in constant_len_dataset])
+ dataset = Dataset.from_list(list(constant_len_dataset))
  # filter out bad data
  dataset = Dataset.from_list(
@@ -335,9 +338,9 @@ def load_prepare_datasets(
  d
  for d in dataset
  if len(d["input_ids"]) < cfg.sequence_len
- and len(d["input_ids"]) > 0 # noqa: W503
+ and len(d["input_ids"]) > 0
- and len(d["input_ids"]) == len(d["attention_mask"]) # noqa: W503
+ and len(d["input_ids"]) == len(d["attention_mask"])
- and len(d["input_ids"]) == len(d["labels"]) # noqa: W503
+ and len(d["input_ids"]) == len(d["labels"])
  ]
  )