From 8836986a920bfe016933a142cac91f3105100df4 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 10 Nov 2023 02:35:19 -0500
Subject: [PATCH] support for fp8

---
 src/axolotl/core/trainer_builder.py | 4 ++++
 src/axolotl/utils/config.py         | 4 +++-
 src/axolotl/utils/trainer.py        | 2 ++
 3 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 226988037..52482cc71 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -483,6 +483,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["fp16"] = (
             self.cfg.fp16 and not self.cfg.bf16
         ) or False
+        if self.cfg.fp8:
+            training_arguments_kwargs["fp16"] = False
+            training_arguments_kwargs["bf16"] = False
+
         training_arguments_kwargs["tf32"] = self.cfg.tf32
         training_arguments_kwargs["warmup_steps"] = warmup_steps
         training_arguments_kwargs["logging_steps"] = logging_steps
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 81660ae65..ccee27be9 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -70,7 +70,9 @@ def normalize_config(cfg):
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False

-    if cfg.bf16 or cfg.bfloat16:
+    if cfg.fp8:
+        cfg.torch_dtype = torch.bfloat16
+    elif cfg.bf16 or cfg.bfloat16:
         cfg.torch_dtype = torch.bfloat16
     elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
         cfg.torch_dtype = torch.float16
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index f93316cde..5997d28a0 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -268,6 +268,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
         setup_fsdp_envs(cfg)
     elif cfg.deepspeed:
         os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+    if cfg.fp8:
+        os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"

     trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
     trainer_builder.train_dataset = train_dataset