various bugfixes

This commit is contained in:
Wing Lian
2023-04-14 21:37:07 -04:00
parent 45f77dd51e
commit 80b2ed29d8
5 changed files with 33 additions and 26 deletions

View File

@@ -93,22 +93,24 @@ class ConstantLengthDataset(IterableDataset):
buffer_len = 0
if example:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
# just going to drop data points that are too long
if len(example["input_ids"]) <= self.seq_length:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if add_concat_token:
input_ids.append(self.concat_token_id)
attention_mask.append(1)
labels.append(self.concat_token_id)
if add_concat_token:
input_ids.append(self.concat_token_id)
attention_mask.append(1)
labels.append(self.concat_token_id)
input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
attention_mask_with_concat = torch.tensor(
attention_mask, dtype=torch.long
)
labels_with_concat = torch.tensor(labels, dtype=torch.long)
input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
attention_mask_with_concat = torch.tensor(
attention_mask, dtype=torch.long
)
labels_with_concat = torch.tensor(labels, dtype=torch.long)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer_len += len(input_ids)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer_len += len(input_ids)