diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 12e85e15e..550801196 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -56,6 +56,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): training_arguments_kwargs["bf16_full_eval"] = True else: training_arguments_kwargs["bf16"] = cfg.bf16 + training_arguments_kwargs["fp16"] = True if cfg.fp16 and not cfg.bf16 else False training_arguments_kwargs["tf32"] = cfg.tf32 training_arguments_kwargs["warmup_steps"] = warmup_steps training_arguments_kwargs["logging_steps"] = logging_steps