wip flex block mask creation

This commit is contained in:
bursteratom
2025-01-29 00:25:25 -05:00
parent b31796a681
commit ba88bc7840

View File

@@ -169,19 +169,26 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if feature == "length": if feature == "length":
continue continue
if feature == "attention_mask": if feature == "attention_mask":
seq_len_list = [ arrays = [
get_seqlens_from_pos_ids(item["position_ids"]) (i + 1) * np.array(item[feature])
for item in features_ for i, item in enumerate(features_)
if feature in item if feature in item
] ]
out_features[i][feature] = np.concatenate(arrays)
out_features[i][feature] = np.concatenate(seq_len_list)
else: else:
arrays = [ arrays = [
np.array(item[feature]) for item in features_ if feature in item np.array(item[feature]) for item in features_ if feature in item
] ]
out_features[i][feature] = np.concatenate(arrays) out_features[i][feature] = np.concatenate(arrays)
return super().__call__(out_features, return_tensors=return_tensors) out = super().__call__(out_features, return_tensors=return_tensors)
collated_seq_lens = get_seqlens_from_pos_ids(out["position_ids"])
doc_mask =
out["attention_mask"]
return out
@dataclass @dataclass