diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 784b3b697..1ca992b83 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -830,7 +830,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): kwargs.pop("max_length") elif use_batch_sampler_collator: if self.cfg.flex_attention is True: - collator = FlexBatchSamplerDataCollatorForSeq2Seq + collator = V2BatchSamplerDataCollatorForSeq2Seq elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES: collator = V2BatchSamplerDataCollatorForSeq2Seq elif (