Lint and format
This commit is contained in:
@@ -82,10 +82,8 @@ class ConstantLengthDataset(IterableDataset):
|
||||
else:
|
||||
example_len = 0
|
||||
|
||||
if (
|
||||
not example_len
|
||||
or buffer_len + int(add_concat_token) + example_len
|
||||
> self.seq_length
|
||||
if not example_len or (
|
||||
buffer_len + int(add_concat_token) + example_len > self.seq_length
|
||||
):
|
||||
if buffer["input_ids"]:
|
||||
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
|
||||
@@ -95,9 +93,8 @@ class ConstantLengthDataset(IterableDataset):
|
||||
: self.seq_length
|
||||
]
|
||||
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
||||
if (
|
||||
labels.size() == input_ids.size()
|
||||
and attention_mask.size() == input_ids.size()
|
||||
if labels.size() == input_ids.size() and (
|
||||
attention_mask.size() == input_ids.size()
|
||||
):
|
||||
yield {
|
||||
"input_ids": input_ids,
|
||||
|
||||
Reference in New Issue
Block a user