Lint and format
@@ -82,10 +82,8 @@ class ConstantLengthDataset(IterableDataset):
                 else:
                     example_len = 0
 
-                if (
-                    not example_len
-                    or buffer_len + int(add_concat_token) + example_len
-                    > self.seq_length
-                ):
+                if not example_len or (
+                    buffer_len + int(add_concat_token) + example_len > self.seq_length
+                ):
                     if buffer["input_ids"]:
                         input_ids = torch.cat(buffer["input_ids"], dim=-1)[
@@ -95,9 +93,8 @@ class ConstantLengthDataset(IterableDataset):
                             : self.seq_length
                         ]
                         labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
-                        if (
-                            labels.size() == input_ids.size()
-                            and attention_mask.size() == input_ids.size()
-                        ):
+                        if labels.size() == input_ids.size() and (
+                            attention_mask.size() == input_ids.size()
+                        ):
                             yield {
                                 "input_ids": input_ids,
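For context, the condition reformatted in the two hunks above is the packing check in ConstantLengthDataset: examples accumulate in a buffer until the next one (plus an optional concat token) would overflow seq_length, at which point the buffer is flushed and yielded. A minimal standalone sketch of that check, using illustrative names rather than the class's actual attributes:

def should_flush(example_len: int, buffer_len: int, add_concat_token: bool, seq_length: int) -> bool:
    # Flush when the source iterator is exhausted (example_len == 0) or when appending
    # the next example plus one optional separator token would exceed seq_length.
    return not example_len or (
        buffer_len + int(add_concat_token) + example_len > seq_length
    )

assert should_flush(example_len=0, buffer_len=10, add_concat_token=False, seq_length=64)
assert should_flush(example_len=60, buffer_len=10, add_concat_token=True, seq_length=64)
assert not should_flush(example_len=50, buffer_len=10, add_concat_token=True, seq_length=64)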
@@ -1,14 +1,12 @@
 import logging
 from hashlib import md5
 from pathlib import Path
-from typing import Union
+from typing import Tuple, Union
 
 from datasets import (
     load_from_disk,
     load_dataset,
     IterableDataset,
     Dataset,
     concatenate_datasets,
     DatasetDict,
 )
 from huggingface_hub import hf_hub_download
@@ -48,10 +46,12 @@ def load_tokenized_prepared_datasets(
         md5(
             (
                 str(cfg.sequence_len)
-                + "@"
-                + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
-                + "|"
-                + tokenizer_name
+                + "@"  # noqa: W503
+                + "|".join(  # noqa: W503
+                    sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
+                )
+                + "|"  # noqa: W503
+                + tokenizer_name  # noqa: W503
             ).encode("utf-8")
         ).hexdigest()
     )
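The noqa: W503 markers added above silence flake8's line-break-before-binary-operator warning, which conflicts with black's preferred wrapping. The expression itself builds a cache fingerprint for the prepared dataset. A hedged sketch of the same idea over plain dicts (field names assumed from the snippet, not the project's config objects):

from hashlib import md5

def dataset_fingerprint(sequence_len, datasets, tokenizer_name):
    # Hash the sequence length, a sorted "path:type:shards" listing, and the
    # tokenizer name into a stable cache key, mirroring the expression above.
    key = (
        str(sequence_len)
        + "@"
        + "|".join(sorted(f"{d['path']}:{d['type']}:{d['shards']}" for d in datasets))
        + "|"
        + tokenizer_name
    )
    return md5(key.encode("utf-8")).hexdigest()

print(dataset_fingerprint(2048, [{"path": "some/dataset", "type": "alpaca", "shards": None}], "some-tokenizer"))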
@@ -68,7 +68,7 @@ def load_tokenized_prepared_datasets(
                 f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
             )
             dataset = dataset["train"]
-        except:
+        except Exception:  # pylint: disable=broad-except
             pass
 
     if dataset:
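The change from a bare except: to except Exception: is more than style: a bare except also catches BaseException subclasses such as KeyboardInterrupt and SystemExit, which this code has no reason to swallow. A tiny, hypothetical illustration of the narrower pattern:

def safe_call(fn):
    # Ordinary errors are swallowed, but KeyboardInterrupt/SystemExit still
    # propagate because only Exception (not BaseException) is caught.
    try:
        return fn()
    except Exception:  # pylint: disable=broad-except
        return None

assert safe_call(lambda: 1 / 0) is None
assert safe_call(lambda: "ok") == "ok"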
@@ -109,15 +109,21 @@ def load_tokenized_prepared_datasets(
             fp = hf_hub_download(
                 repo_id=d.path, repo_type="dataset", filename=d.data_files
             )
-            ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None)
+            ds: Dataset = load_dataset(
+                "json", data_files=fp, streaming=False, split=None
+            )
         if not ds:
-            raise Exception("unhandled dataset load")
+            raise ValueError("unhandled dataset load")
         # support for using a subset of the data
         if d.shards:
             if "train" in ds:
-                ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
+                ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(
+                    num_shards=d.shards, index=0
+                )
             else:
-                ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
+                ds: Dataset = ds.shuffle(seed=42).shard(
+                    num_shards=d.shards, index=0
+                )
         d_type = d.type
         d_type_split = d_type.split(":")
         d_base_type = d_type_split[0]
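The shard(num_shards=..., index=0) calls wrapped above take a deterministic subset of a dataset after a seeded shuffle, which is how d.shards limits how much of a dataset is used. A small self-contained sketch with an in-memory Dataset (the data is made up):

from datasets import Dataset

ds = Dataset.from_dict({"text": [f"example {i}" for i in range(100)]})
# Shuffle reproducibly, then keep the first of four equal slices (25 rows).
subset = ds.shuffle(seed=42).shard(num_shards=4, index=0)
print(len(subset))  # 25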
@@ -243,7 +249,7 @@ def load_tokenized_prepared_datasets(
 
 def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
-) -> (Dataset, Dataset):
+) -> Tuple[Dataset, Dataset]:
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
     )
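The annotation fix above is also a correctness fix: ") -> (Dataset, Dataset):" evaluates to a tuple of classes, which type checkers do not accept as a return type, whereas Tuple[Dataset, Dataset] is a proper type expression. A minimal example of the corrected pattern with a hypothetical function:

from typing import Tuple

def split_pair(items: list) -> Tuple[list, list]:
    # Two values returned as a tuple; Tuple[...] is the annotation checkers understand.
    mid = len(items) // 2
    return items[:mid], items[mid:]

first, second = split_pair([1, 2, 3, 4])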
@@ -259,12 +265,14 @@ def load_prepare_datasets(
         md5(
             (
                 str(cfg.sequence_len)
-                + "@"
-                + str(max_packed_sequence_len)
-                + seed
-                + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
-                + "|"
-                + tokenizer_name
+                + "@"  # noqa: W503
+                + str(max_packed_sequence_len)  # noqa: W503
+                + seed  # noqa: W503
+                + "|".join(  # noqa: W503
+                    sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
+                )
+                + "|"  # noqa: W503
+                + tokenizer_name  # noqa: W503
             ).encode("utf-8")
         ).hexdigest()
     )
@@ -285,7 +293,7 @@ def load_prepare_datasets(
                 f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
             )
             dataset = dataset["train"]
-        except:
+        except Exception:  # pylint: disable=broad-except
            pass
 
     if dataset:
@@ -327,9 +335,9 @@ def load_prepare_datasets(
                 d
                 for d in dataset
                 if len(d["input_ids"]) < cfg.sequence_len
-                and len(d["input_ids"]) > 0
-                and len(d["input_ids"]) == len(d["attention_mask"])
-                and len(d["input_ids"]) == len(d["labels"])
+                and len(d["input_ids"]) > 0  # noqa: W503
+                and len(d["input_ids"]) == len(d["attention_mask"])  # noqa: W503
+                and len(d["input_ids"]) == len(d["labels"])  # noqa: W503
             ]
         )
 
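The comprehension above drops examples that are empty, at least sequence_len tokens long, or have mismatched input_ids/attention_mask/labels lengths. A standalone sketch of the same filter over plain dicts (made-up data, names mirroring the snippet):

SEQ_LEN = 4

samples = [
    {"input_ids": [1, 2], "attention_mask": [1, 1], "labels": [1, 2]},       # kept
    {"input_ids": [], "attention_mask": [], "labels": []},                   # dropped: empty
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1], "labels": [1, 2, 3]}, # dropped: mismatch
]

kept = [
    d
    for d in samples
    if len(d["input_ids"]) < SEQ_LEN
    and len(d["input_ids"]) > 0
    and len(d["input_ids"]) == len(d["attention_mask"])
    and len(d["input_ids"]) == len(d["labels"])
]
assert len(kept) == 1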