test v2batch w/ flex attn

This commit is contained in:
bursteratom
2025-02-13 00:11:45 -05:00
parent 0ef1f011fe
commit 82d04ea060

View File

@@ -830,7 +830,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if self.cfg.flex_attention is True:
collator = FlexBatchSamplerDataCollatorForSeq2Seq
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (