test v2batch w/ flex attn
This commit is contained in:
@@ -830,7 +830,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
kwargs.pop("max_length")
|
kwargs.pop("max_length")
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
if self.cfg.flex_attention is True:
|
if self.cfg.flex_attention is True:
|
||||||
collator = FlexBatchSamplerDataCollatorForSeq2Seq
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
elif (
|
elif (
|
||||||
|
|||||||
Reference in New Issue
Block a user