various bugfixes

This commit is contained in:
Wing Lian
2023-04-19 17:04:34 -04:00
parent 2624bc2f11
commit 94f5e415a3
6 changed files with 63 additions and 10 deletions

View File

@@ -1,3 +1,4 @@
import logging
from typing import List
import torch
@@ -92,11 +93,14 @@ class ConstantLengthDataset(IterableDataset):
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
}
if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size():
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
}
else:
logging.warning("dropping batch due to tensor size mismatch")
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
buffer_len = 0