This commit is contained in:
bursteratom
2025-02-02 00:51:43 -05:00
parent b692d394b1
commit b832b11c8f

View File

@@ -180,8 +180,8 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
out = super().__call__(out_features, return_tensors=return_tensors)
collated_seq_lens, max_seq_len = get_seqlens_from_pos_ids(out["position_ids"])
-# out["attention_mask"] = packed_block_causal_mask(collated_seq_lens)
-out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
+out["attention_mask"] = packed_block_causal_mask(collated_seq_lens)
+# out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
# raise ValueError(f"{out['attention_mask'].shape}")
return out