From cb7cd3429fba1aa83d7827759f6e09e2441de409 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 09:21:08 +0900 Subject: [PATCH] Fix data.py lint --- src/axolotl/utils/data.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 6d2123eea..32654e104 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -1,3 +1,5 @@ +"""Module containing data utilities for Axolotl""" + import logging from hashlib import md5 from pathlib import Path @@ -46,12 +48,12 @@ def load_tokenized_prepared_datasets( md5( ( str(cfg.sequence_len) - + "@" # noqa: W503 - + "|".join( # noqa: W503 + + "@" + + "|".join( sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]) ) - + "|" # noqa: W503 - + tokenizer_name # noqa: W503 + + "|" + + tokenizer_name ).encode("utf-8") ).hexdigest() ) @@ -81,6 +83,7 @@ def load_tokenized_prepared_datasets( logging.info(f"Unable to find prepared dataset in {prepared_ds_path}") logging.info("Loading raw datasets...") datasets = [] + # pylint: disable=invalid-name for d in cfg.datasets: ds: Union[Dataset, DatasetDict] = None ds_from_hub = False @@ -229,7 +232,7 @@ def load_tokenized_prepared_datasets( samples = [] for d in datasets: - samples = samples + [i for i in d] + samples = samples + list(d) dataset = Dataset.from_list(samples).shuffle(seed=42) if cfg.local_rank == 0: logging.info( @@ -265,14 +268,14 @@ def load_prepare_datasets( md5( ( str(cfg.sequence_len) - + "@" # noqa: W503 - + str(max_packed_sequence_len) # noqa: W503 - + seed # noqa: W503 - + "|".join( # noqa: W503 + + "@" + + str(max_packed_sequence_len) + + seed + + "|".join( sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]) ) - + "|" # noqa: W503 - + tokenizer_name # noqa: W503 + + "|" + + tokenizer_name ).encode("utf-8") ).hexdigest() ) @@ -327,7 +330,7 @@ def load_prepare_datasets( logging.info( f"packing master dataset to len: {cfg.max_packed_sequence_len}" ) - dataset = Dataset.from_list([_ for _ in constant_len_dataset]) + dataset = Dataset.from_list(list(constant_len_dataset)) # filter out bad data dataset = Dataset.from_list( @@ -335,9 +338,9 @@ def load_prepare_datasets( d for d in dataset if len(d["input_ids"]) < cfg.sequence_len - and len(d["input_ids"]) > 0 # noqa: W503 - and len(d["input_ids"]) == len(d["attention_mask"]) # noqa: W503 - and len(d["input_ids"]) == len(d["labels"]) # noqa: W503 + and len(d["input_ids"]) > 0 + and len(d["input_ids"]) == len(d["attention_mask"]) + and len(d["input_ids"]) == len(d["labels"]) ] )