various bugfixes
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
@@ -92,11 +93,14 @@ class ConstantLengthDataset(IterableDataset):
|
||||
: self.seq_length
|
||||
]
|
||||
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
||||
yield {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size():
|
||||
yield {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
else:
|
||||
logging.warning("dropping batch due to tensor size mismatch")
|
||||
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
buffer_len = 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user