fix: consolidate handling of fp16, bf16, tf32 kwargs

This commit is contained in:
NanoCode012
2025-01-28 14:22:59 +07:00
parent c268a0157a
commit fd271b2547

View File

@@ -252,6 +252,14 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["warmup_steps"] = warmup_steps
training_args_kwargs["logging_steps"] = logging_steps
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
training_args_kwargs["tf32"] = self.cfg.tf32
if self.cfg.bf16 == "full":
training_args_kwargs["bf16_full_eval"] = True
else:
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
if self.cfg.hub_model_id:
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_args_kwargs["push_to_hub"] = True
@@ -433,14 +441,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
def build(self, total_num_steps):
training_arguments_kwargs = self._set_base_training_args(total_num_steps)
if self.cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
else:
training_arguments_kwargs["bf16"] = self.cfg.bf16
training_arguments_kwargs["fp16"] = (
self.cfg.fp16 and not self.cfg.bf16
) or False
training_arguments_kwargs["tf32"] = self.cfg.tf32
if self.cfg.seed is not None:
training_arguments_kwargs["seed"] = self.cfg.seed
@@ -1014,9 +1014,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
elif self.cfg.eval_strategy:
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
if self.cfg.bf16 or self.cfg.bfloat16:
training_args_kwargs["bf16"] = True
training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
training_args_kwargs["lr_scheduler_type"] = (