fix: consolidate handling of fp16, bf16, tf32 kwarg
@@ -252,6 +252,14 @@ class TrainerBuilderBase(abc.ABC):
         training_args_kwargs["warmup_steps"] = warmup_steps
         training_args_kwargs["logging_steps"] = logging_steps

+        training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
+        training_args_kwargs["tf32"] = self.cfg.tf32
+
+        if self.cfg.bf16 == "full":
+            training_args_kwargs["bf16_full_eval"] = True
+        else:
+            training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
+
         if self.cfg.hub_model_id:
             training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
             training_args_kwargs["push_to_hub"] = True
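For reference, a minimal, self-contained sketch of how the consolidated block above resolves the precision kwargs. The Cfg dataclass and resolve_precision_kwargs helper are illustrative stand-ins for the real config object, not names from the diff:

from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class Cfg:
    # stand-in for the trainer config; only the fields used in the hunk above
    fp16: Optional[bool] = None
    bf16: Union[bool, str, None] = None   # True, False, None, or "full"
    bfloat16: Optional[bool] = None       # legacy alias for bf16
    tf32: Optional[bool] = None


def resolve_precision_kwargs(cfg: Cfg) -> dict:
    kwargs = {}
    # fp16 only when bf16 is not requested; `or False` normalizes None to False
    kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
    kwargs["tf32"] = cfg.tf32
    if cfg.bf16 == "full":
        # "full" enables bf16 evaluation as well (bf16_full_eval)
        kwargs["bf16_full_eval"] = True
    else:
        kwargs["bf16"] = cfg.bf16 or cfg.bfloat16
    return kwargs


# bf16 wins over fp16 when both are set
print(resolve_precision_kwargs(Cfg(fp16=True, bf16=True)))
# {'fp16': False, 'tf32': None, 'bf16': True}
print(resolve_precision_kwargs(Cfg(bf16="full", tf32=True)))
# {'fp16': False, 'tf32': True, 'bf16_full_eval': True}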
@@ -433,14 +441,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

     def build(self, total_num_steps):
         training_arguments_kwargs = self._set_base_training_args(total_num_steps)
-        if self.cfg.bf16 == "full":
-            training_arguments_kwargs["bf16_full_eval"] = True
-        else:
-            training_arguments_kwargs["bf16"] = self.cfg.bf16
-        training_arguments_kwargs["fp16"] = (
-            self.cfg.fp16 and not self.cfg.bf16
-        ) or False
-        training_arguments_kwargs["tf32"] = self.cfg.tf32

         if self.cfg.seed is not None:
             training_arguments_kwargs["seed"] = self.cfg.seed
@@ -1014,9 +1014,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         elif self.cfg.eval_strategy:
            training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy

-        if self.cfg.bf16 or self.cfg.bfloat16:
-            training_args_kwargs["bf16"] = True
-
         training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
         training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
         training_args_kwargs["lr_scheduler_type"] = (
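Taken together, the net effect is that the precision flags are resolved in one place. Before this change the causal builder re-derived fp16, bf16, and tf32 inside build() and ignored cfg.bfloat16, while the RL builder handled only bf16 (never fp16, tf32, or bf16_full_eval); after it, both inherit the single block added to TrainerBuilderBase in the first hunk.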