non-seq2se1 collator fix

This commit is contained in:
Dan Saunders
2025-03-17 13:42:49 +00:00
parent 64c203cdef
commit 2727d86544

View File

@@ -911,7 +911,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt"
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
if issubclass(collator, DataCollatorForSeq2Seq):
print(type(collator), issubclass(collator, DataCollatorForSeq2Seq))
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
return collator(
*collator_args,