wip flex block mask creation
This commit is contained in:
@@ -169,19 +169,26 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
if feature == "length":
|
if feature == "length":
|
||||||
continue
|
continue
|
||||||
if feature == "attention_mask":
|
if feature == "attention_mask":
|
||||||
seq_len_list = [
|
arrays = [
|
||||||
get_seqlens_from_pos_ids(item["position_ids"])
|
(i + 1) * np.array(item[feature])
|
||||||
for item in features_
|
for i, item in enumerate(features_)
|
||||||
if feature in item
|
if feature in item
|
||||||
]
|
]
|
||||||
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
out_features[i][feature] = np.concatenate(seq_len_list)
|
|
||||||
else:
|
else:
|
||||||
arrays = [
|
arrays = [
|
||||||
np.array(item[feature]) for item in features_ if feature in item
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
out = super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
collated_seq_lens = get_seqlens_from_pos_ids(out["position_ids"])
|
||||||
|
|
||||||
|
doc_mask =
|
||||||
|
|
||||||
|
out["attention_mask"]
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
Reference in New Issue
Block a user