From fd271b25475d7616ce6fb2e2383a8c725c7452c5 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Tue, 28 Jan 2025 14:22:59 +0700
Subject: [PATCH] fix: consolidate handling of fp16, bf16, tf32 kwarg

---
 src/axolotl/core/trainer_builder.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index da6384c8e..6fb36f299 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -252,6 +252,14 @@ class TrainerBuilderBase(abc.ABC):
         training_args_kwargs["warmup_steps"] = warmup_steps
         training_args_kwargs["logging_steps"] = logging_steps
 
+        training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
+        training_args_kwargs["tf32"] = self.cfg.tf32
+
+        if self.cfg.bf16 == "full":
+            training_args_kwargs["bf16_full_eval"] = True
+        else:
+            training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
+
         if self.cfg.hub_model_id:
             training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
             training_args_kwargs["push_to_hub"] = True
@@ -433,14 +441,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
     def build(self, total_num_steps):
         training_arguments_kwargs = self._set_base_training_args(total_num_steps)
 
-        if self.cfg.bf16 == "full":
-            training_arguments_kwargs["bf16_full_eval"] = True
-        else:
-            training_arguments_kwargs["bf16"] = self.cfg.bf16
-        training_arguments_kwargs["fp16"] = (
-            self.cfg.fp16 and not self.cfg.bf16
-        ) or False
-        training_arguments_kwargs["tf32"] = self.cfg.tf32
 
         if self.cfg.seed is not None:
             training_arguments_kwargs["seed"] = self.cfg.seed
@@ -1014,9 +1014,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         elif self.cfg.eval_strategy:
             training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
 
-        if self.cfg.bf16 or self.cfg.bfloat16:
-            training_args_kwargs["bf16"] = True
-
         training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
         training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
         training_args_kwargs["lr_scheduler_type"] = (
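
For review context, below is a minimal standalone sketch of the precision handling this patch consolidates into TrainerBuilderBase. The resolve_precision_kwargs helper and the SimpleNamespace cfg are illustrative only, not axolotl APIs; the body mirrors the added lines in the first hunk.

# Illustrative sketch, not part of the patch. Shows how the consolidated
# block resolves the fp16/bf16/tf32 training-args kwargs from a config.
from types import SimpleNamespace


def resolve_precision_kwargs(cfg):
    kwargs = {}
    # fp16 only applies when bf16 is not requested.
    kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
    kwargs["tf32"] = cfg.tf32
    # bf16 == "full" enables bf16 evaluation; otherwise either the
    # `bf16` flag or the `bfloat16` alias enables bf16 training.
    if cfg.bf16 == "full":
        kwargs["bf16_full_eval"] = True
    else:
        kwargs["bf16"] = cfg.bf16 or cfg.bfloat16
    return kwargs


cfg = SimpleNamespace(fp16=True, bf16=False, bfloat16=False, tf32=None)
print(resolve_precision_kwargs(cfg))
# {'fp16': True, 'tf32': None, 'bf16': False}

cfg = SimpleNamespace(fp16=False, bf16="full", bfloat16=False, tf32=True)
print(resolve_precision_kwargs(cfg))
# {'fp16': False, 'tf32': True, 'bf16_full_eval': True}

Note that folding this into the shared base also makes the causal builder honor the `bfloat16` alias, which previously only the RL builder checked.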