diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 908947876..c6ec59132 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -911,7 +911,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): collator = DataCollatorForSeq2Seq kwargs["return_tensors"] = "pt" - kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree + if issubclass(collator, DataCollatorForSeq2Seq): + print(type(collator), issubclass(collator, DataCollatorForSeq2Seq)) + kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree return collator( *collator_args,