test v2batch w/ flex attn
This commit is contained in:
@@ -830,7 +830,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
kwargs.pop("max_length")
|
||||
elif use_batch_sampler_collator:
|
||||
if self.cfg.flex_attention is True:
|
||||
collator = FlexBatchSamplerDataCollatorForSeq2Seq
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
elif (
|
||||
|
||||
Reference in New Issue
Block a user