Lint and format

This commit is contained in:
NanoCode012
2023-05-29 03:45:42 +09:00
parent a98deb31a6
commit 392dfd9b07
9 changed files with 82 additions and 58 deletions

View File

@@ -82,10 +82,8 @@ class ConstantLengthDataset(IterableDataset):
                 else:
                     example_len = 0

-                if (
-                    not example_len
-                    or buffer_len + int(add_concat_token) + example_len
-                    > self.seq_length
+                if not example_len or (
+                    buffer_len + int(add_concat_token) + example_len > self.seq_length
                 ):
                     if buffer["input_ids"]:
                         input_ids = torch.cat(buffer["input_ids"], dim=-1)[
@@ -95,9 +93,8 @@ class ConstantLengthDataset(IterableDataset):
                             : self.seq_length
                         ]
                         labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
-                        if (
-                            labels.size() == input_ids.size()
-                            and attention_mask.size() == input_ids.size()
+                        if labels.size() == input_ids.size() and (
+                            attention_mask.size() == input_ids.size()
                         ):
                             yield {
                                 "input_ids": input_ids,

View File

@@ -1,14 +1,12 @@
 import logging
 from hashlib import md5
 from pathlib import Path
-from typing import Union
+from typing import Tuple, Union

 from datasets import (
     load_from_disk,
     load_dataset,
     IterableDataset,
     Dataset,
     concatenate_datasets,
     DatasetDict,
 )
 from huggingface_hub import hf_hub_download
@@ -48,10 +46,12 @@ def load_tokenized_prepared_datasets(
         md5(
             (
                 str(cfg.sequence_len)
-                + "@"
-                + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
-                + "|"
-                + tokenizer_name
+                + "@"  # noqa: W503
+                + "|".join(  # noqa: W503
+                    sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
+                )
+                + "|"  # noqa: W503
+                + tokenizer_name  # noqa: W503
             ).encode("utf-8")
         ).hexdigest()
     )
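The string being hashed here is a fingerprint of the dataset configuration (sequence length, each dataset's path/type/shards, and the tokenizer name), used as a stable name for the prepared dataset on the Hub. The `# noqa: W503` comments suppress flake8's "line break before binary operator" check on each continuation line, a check that disagrees with black's style of breaking before operators. A minimal sketch of the same hashing pattern, with hypothetical config values standing in for the real cfg object:

from hashlib import md5

sequence_len = 2048
datasets = [("teknium/GPT4-LLM-Cleaned", "alpaca", None)]  # hypothetical (path, type, shards) entries
tokenizer_name = "huggyllama/llama-7b"                     # hypothetical tokenizer

ds_hash = str(
    md5(
        (
            str(sequence_len)
            + "@"
            + "|".join(sorted(f"{path}:{dtype}:{shards}" for path, dtype, shards in datasets))
            + "|"
            + tokenizer_name
        ).encode("utf-8")
    ).hexdigest()
)
print(ds_hash)  # same config always yields the same fingerprint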
@@ -68,7 +68,7 @@ def load_tokenized_prepared_datasets(
                 f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
             )
             dataset = dataset["train"]
-        except:
+        except Exception:  # pylint: disable=broad-except
             pass

     if dataset:
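Swapping the bare `except:` for `except Exception:` keeps `KeyboardInterrupt` and `SystemExit` (which derive from `BaseException`, not `Exception`) from being swallowed, while a missing Hub copy of the prepared dataset still just falls through to rebuilding it locally; the pylint pragma marks the broad catch as intentional. Illustrative sketch (helper names are hypothetical):

assert not issubclass(KeyboardInterrupt, Exception)
assert not issubclass(SystemExit, Exception)
assert issubclass(OSError, Exception)  # ordinary errors are still caught

def load_cached(loader):
    try:
        return loader()
    except Exception:  # pylint: disable=broad-except
        return None    # a missing cache is not fatal; caller rebuilds instead

def missing():
    raise OSError("no cached dataset on the Hub")

assert load_cached(missing) is None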
@@ -109,15 +109,21 @@ def load_tokenized_prepared_datasets(
                 fp = hf_hub_download(
                     repo_id=d.path, repo_type="dataset", filename=d.data_files
                 )
-                ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None)
+                ds: Dataset = load_dataset(
+                    "json", data_files=fp, streaming=False, split=None
+                )
             if not ds:
-                raise Exception("unhandled dataset load")
+                raise ValueError("unhandled dataset load")
             # support for using a subset of the data
             if d.shards:
                 if "train" in ds:
-                    ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
+                    ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(
+                        num_shards=d.shards, index=0
+                    )
                 else:
-                    ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
+                    ds: Dataset = ds.shuffle(seed=42).shard(
+                        num_shards=d.shards, index=0
+                    )
             d_type = d.type
             d_type_split = d_type.split(":")
             d_base_type = d_type_split[0]
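Apart from wrapping the long calls, the only change in meaning here is the more specific `ValueError`. When `shards` is set, the subset is taken by shuffling with a fixed seed and then keeping shard 0 of N, i.e. roughly 1/N of the rows, deterministically. A rough standalone sketch of that subsetting on a toy in-memory dataset (not the repo's code path):

from datasets import Dataset

ds = Dataset.from_dict({"text": [f"example {i}" for i in range(100)]})

# shuffle deterministically, then keep the first of `shards` slices
shards = 4
subset = ds.shuffle(seed=42).shard(num_shards=shards, index=0)
print(len(subset))  # 25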
@@ -243,7 +249,7 @@ def load_tokenized_prepared_datasets(

 def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
-) -> (Dataset, Dataset):
+) -> Tuple[Dataset, Dataset]:
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
     )
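`-> (Dataset, Dataset)` is a tuple literal rather than a valid return annotation (it evaluates to a plain tuple of classes, which type checkers and pylint flag); `Tuple[Dataset, Dataset]` is the correct spelling, which is why `Tuple` was added to the typing import at the top of this file. Illustrative comparison with toy functions (not from the repo):

from typing import Tuple

def pair_bad() -> (int, str):        # evaluates to the tuple (int, str); flagged by linters
    return 1, "a"

def pair_good() -> Tuple[int, str]:  # proper generic annotation
    return 1, "a"

print(pair_bad.__annotations__["return"])   # (<class 'int'>, <class 'str'>)
print(pair_good.__annotations__["return"])  # typing.Tuple[int, str]

On Python 3.9+ the builtin `tuple[Dataset, Dataset]` would also be accepted, but `typing.Tuple` works on older interpreters as well.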
@@ -259,12 +265,14 @@ def load_prepare_datasets(
         md5(
             (
                 str(cfg.sequence_len)
-                + "@"
-                + str(max_packed_sequence_len)
-                + seed
-                + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
-                + "|"
-                + tokenizer_name
+                + "@"  # noqa: W503
+                + str(max_packed_sequence_len)  # noqa: W503
+                + seed  # noqa: W503
+                + "|".join(  # noqa: W503
+                    sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
+                )
+                + "|"  # noqa: W503
+                + tokenizer_name  # noqa: W503
             ).encode("utf-8")
         ).hexdigest()
     )
@@ -285,7 +293,7 @@ def load_prepare_datasets(
                 f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
             )
             dataset = dataset["train"]
-        except:
+        except Exception:  # pylint: disable=broad-except
             pass

     if dataset:
@@ -327,9 +335,9 @@ def load_prepare_datasets(
                 d
                 for d in dataset
                 if len(d["input_ids"]) < cfg.sequence_len
-                and len(d["input_ids"]) > 0
-                and len(d["input_ids"]) == len(d["attention_mask"])
-                and len(d["input_ids"]) == len(d["labels"])
+                and len(d["input_ids"]) > 0  # noqa: W503
+                and len(d["input_ids"]) == len(d["attention_mask"])  # noqa: W503
+                and len(d["input_ids"]) == len(d["labels"])  # noqa: W503
             ]
         )
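Same W503 situation as the hash blocks above: black breaks long boolean expressions before the operator, flake8's W503 check ("line break before binary operator") flags exactly that style, so each continuation line carries a targeted `# noqa`. A toy version of the filter, runnable on plain dicts instead of a `datasets.Dataset`:

sequence_len = 4
dataset = [
    {"input_ids": [1, 2], "attention_mask": [1, 1], "labels": [1, 2]},            # kept
    {"input_ids": [], "attention_mask": [], "labels": []},                        # dropped: empty
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1], "labels": [1, 2, 3]},      # dropped: mask length mismatch
]
kept = [
    d
    for d in dataset
    if len(d["input_ids"]) < sequence_len
    and len(d["input_ids"]) > 0  # noqa: W503
    and len(d["input_ids"]) == len(d["attention_mask"])  # noqa: W503
    and len(d["input_ids"]) == len(d["labels"])  # noqa: W503
]
assert len(kept) == 1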