black formatting
This commit is contained in:
@@ -71,10 +71,18 @@ class ConstantLengthDataset(IterableDataset):
|
||||
else:
|
||||
example_len = 0
|
||||
|
||||
if not example_len or buffer_len + int(add_concat_token) + example_len > self.seq_length:
|
||||
if (
|
||||
not example_len
|
||||
or buffer_len + int(add_concat_token) + example_len
|
||||
> self.seq_length
|
||||
):
|
||||
if buffer["input_ids"]:
|
||||
input_ids = torch.cat(buffer["input_ids"], dim=-1)[: self.seq_length]
|
||||
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[: self.seq_length]
|
||||
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
|
||||
: self.seq_length
|
||||
]
|
||||
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
|
||||
: self.seq_length
|
||||
]
|
||||
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
||||
yield {
|
||||
"input_ids": input_ids,
|
||||
@@ -95,7 +103,9 @@ class ConstantLengthDataset(IterableDataset):
|
||||
labels.append(self.concat_token_id)
|
||||
|
||||
input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
|
||||
attention_mask_with_concat = torch.tensor(attention_mask, dtype=torch.long)
|
||||
attention_mask_with_concat = torch.tensor(
|
||||
attention_mask, dtype=torch.long
|
||||
)
|
||||
labels_with_concat = torch.tensor(labels, dtype=torch.long)
|
||||
|
||||
buffer["input_ids"].append(input_ids_with_concat)
|
||||
|
||||
Reference in New Issue
Block a user