various bugfixes
This commit is contained in:
@@ -93,22 +93,24 @@ class ConstantLengthDataset(IterableDataset):
|
||||
buffer_len = 0
|
||||
|
||||
if example:
|
||||
input_ids = example["input_ids"]
|
||||
attention_mask = example["attention_mask"]
|
||||
labels = example["labels"]
|
||||
# just going to drop data points that are too long
|
||||
if len(example["input_ids"]) <= self.seq_length:
|
||||
input_ids = example["input_ids"]
|
||||
attention_mask = example["attention_mask"]
|
||||
labels = example["labels"]
|
||||
|
||||
if add_concat_token:
|
||||
input_ids.append(self.concat_token_id)
|
||||
attention_mask.append(1)
|
||||
labels.append(self.concat_token_id)
|
||||
if add_concat_token:
|
||||
input_ids.append(self.concat_token_id)
|
||||
attention_mask.append(1)
|
||||
labels.append(self.concat_token_id)
|
||||
|
||||
input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
|
||||
attention_mask_with_concat = torch.tensor(
|
||||
attention_mask, dtype=torch.long
|
||||
)
|
||||
labels_with_concat = torch.tensor(labels, dtype=torch.long)
|
||||
input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
|
||||
attention_mask_with_concat = torch.tensor(
|
||||
attention_mask, dtype=torch.long
|
||||
)
|
||||
labels_with_concat = torch.tensor(labels, dtype=torch.long)
|
||||
|
||||
buffer["input_ids"].append(input_ids_with_concat)
|
||||
buffer["attention_mask"].append(attention_mask_with_concat)
|
||||
buffer["labels"].append(labels_with_concat)
|
||||
buffer_len += len(input_ids)
|
||||
buffer["input_ids"].append(input_ids_with_concat)
|
||||
buffer["attention_mask"].append(attention_mask_with_concat)
|
||||
buffer["labels"].append(labels_with_concat)
|
||||
buffer_len += len(input_ids)
|
||||
|
||||
Reference in New Issue
Block a user