black formatting

This commit is contained in:
Wing Lian
2023-04-14 07:25:52 -04:00
parent 8d959a7e26
commit a6028d302e
6 changed files with 92 additions and 55 deletions

View File

@@ -71,10 +71,18 @@ class ConstantLengthDataset(IterableDataset):
else:
example_len = 0
if not example_len or buffer_len + int(add_concat_token) + example_len > self.seq_length:
if (
not example_len
or buffer_len + int(add_concat_token) + example_len
> self.seq_length
):
if buffer["input_ids"]:
input_ids = torch.cat(buffer["input_ids"], dim=-1)[: self.seq_length]
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[: self.seq_length]
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
: self.seq_length
]
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
yield {
"input_ids": input_ids,
@@ -95,7 +103,9 @@ class ConstantLengthDataset(IterableDataset):
labels.append(self.concat_token_id)
input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
attention_mask_with_concat = torch.tensor(attention_mask, dtype=torch.long)
attention_mask_with_concat = torch.tensor(
attention_mask, dtype=torch.long
)
labels_with_concat = torch.tensor(labels, dtype=torch.long)
buffer["input_ids"].append(input_ids_with_concat)