diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index 01d9792bc..6a04df99e 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -180,8 +180,8 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): out = super().__call__(out_features, return_tensors=return_tensors) collated_seq_lens, max_seq_len = get_seqlens_from_pos_ids(out["position_ids"]) - # out["attention_mask"] = packed_block_causal_mask(collated_seq_lens) - out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len) + out["attention_mask"] = packed_block_causal_mask(collated_seq_lens) + # out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len) # raise ValueError(f"{out['attention_mask'].shape}") return out