Lint and format

This commit is contained in:
NanoCode012
2023-05-29 03:45:42 +09:00
parent a98deb31a6
commit 392dfd9b07
9 changed files with 82 additions and 58 deletions

View File

@@ -82,10 +82,8 @@ class ConstantLengthDataset(IterableDataset):
else:
example_len = 0
if (
not example_len
or buffer_len + int(add_concat_token) + example_len
> self.seq_length
if not example_len or (
buffer_len + int(add_concat_token) + example_len > self.seq_length
):
if buffer["input_ids"]:
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
@@ -95,9 +93,8 @@ class ConstantLengthDataset(IterableDataset):
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if (
labels.size() == input_ids.size()
and attention_mask.size() == input_ids.size()
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
):
yield {
"input_ids": input_ids,