wip flex block mask creation
This commit is contained in:
@@ -169,19 +169,26 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
seq_len_list = [
|
||||
get_seqlens_from_pos_ids(item["position_ids"])
|
||||
for item in features_
|
||||
arrays = [
|
||||
(i + 1) * np.array(item[feature])
|
||||
for i, item in enumerate(features_)
|
||||
if feature in item
|
||||
]
|
||||
|
||||
out_features[i][feature] = np.concatenate(seq_len_list)
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
out = super().__call__(out_features, return_tensors=return_tensors)
|
||||
|
||||
collated_seq_lens = get_seqlens_from_pos_ids(out["position_ids"])
|
||||
|
||||
doc_mask =
|
||||
|
||||
out["attention_mask"]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user