From 1aec93cf9e9984cefe5b63f7e7aaec318fb1a73b Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 6 Apr 2025 23:54:50 -0400
Subject: [PATCH] add preliminary fp8 support

---
 src/axolotl/core/trainers/base.py                   | 7 +++++--
 src/axolotl/integrations/liger/models/llama4.py     | 2 +-
 src/axolotl/monkeypatch/trainer_accelerator_args.py | 4 ++--
 src/axolotl/utils/models.py                         | 9 +++++----
 src/axolotl/utils/schemas/config.py                 | 1 +
 5 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index dd50f8ce7..bc3a200d4 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -562,13 +562,16 @@ class AxolotlTrainer(
 
         return res
 
-    def override_accelerator_args(self, **kwargs):  # pylint: disable=unused-argument
+    def additional_accelerator_args(
+        self, fp8=None, **kwargs
+    ):  # pylint: disable=unused-argument
         ret_kwargs = {}
-        if os.environ.get("ACCELERATE_MIXED_PRECISION") == "fp8":
+        if fp8:
             from accelerate.utils import AORecipeKwargs
 
             ret_kwargs["mixed_precision"] = "fp8"
             ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()]
+            os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
 
         return ret_kwargs
 
diff --git a/src/axolotl/integrations/liger/models/llama4.py b/src/axolotl/integrations/liger/models/llama4.py
index ee7f226cd..da35b114c 100644
--- a/src/axolotl/integrations/liger/models/llama4.py
+++ b/src/axolotl/integrations/liger/models/llama4.py
@@ -45,7 +45,7 @@ def lce_forward(
     Returns:
 
     """
-    print("=" * 30 + " lce_forward " + "=" * 30)
+    # pylint: disable=duplicate-code
     output_attentions = (
         output_attentions
         if output_attentions is not None
diff --git a/src/axolotl/monkeypatch/trainer_accelerator_args.py b/src/axolotl/monkeypatch/trainer_accelerator_args.py
index 8c68f2c8a..d87812c9f 100644
--- a/src/axolotl/monkeypatch/trainer_accelerator_args.py
+++ b/src/axolotl/monkeypatch/trainer_accelerator_args.py
@@ -17,8 +17,8 @@ ORIGINAL_TRAINER_CODE = """
 """
 
 PATCHED_TRAINER_CODE = """
-        if hasattr(self, "override_accelerator_args"):
-            additional_args = self.override_accelerator_args(**args)
+        if hasattr(self, "additional_accelerator_args"):
+            additional_args = self.additional_accelerator_args(fp8=True, **args)
             if additional_args:
                 args.update(additional_args)
 
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 79f6c5a9b..367e69850 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -558,11 +558,12 @@ class ModelLoader:
         plugin_manager.pre_model_load(self.cfg)
 
         # monkey patch to allow additional Accelerator init kwargs
-        from axolotl.monkeypatch.trainer_accelerator_args import (
-            patch_create_accelerate_code_for_fp8,
-        )
+        if self.cfg.fp8:
+            from axolotl.monkeypatch.trainer_accelerator_args import (
+                patch_create_accelerate_code_for_fp8,
+            )
 
-        patch_create_accelerate_code_for_fp8()
+            patch_create_accelerate_code_for_fp8()
 
         if self.cfg.adapter:
             from axolotl.monkeypatch.transformers_fa_utils import (
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 4995962df..0f9a3a1f9 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -169,6 +169,7 @@ class AxolotlInputConfig(
 
     bf16: Literal["auto"] | bool | None = "auto"
    fp16: bool | None = None
+    fp8: bool | None = None
     bfloat16: bool | None = None  # for non-AMP cases
     float16: bool | None = None  # for non-AMP cases
     tf32: bool | None = None
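
Usage note (not part of the patch): with this change, setting fp8: true in the axolotl YAML config flows through the new additional_accelerator_args hook, which hands Accelerate the torchao FP8 recipe instead of relying on the ACCELERATE_MIXED_PRECISION environment variable being set externally. A minimal standalone sketch of the Accelerator setup the trainer ends up with is below; it assumes a recent accelerate release that ships AORecipeKwargs, torchao installed, and FP8-capable hardware, and everything outside the names taken from the patch is illustrative.

    # Sketch only: mirrors what additional_accelerator_args() feeds into the Trainer's Accelerator.
    import os

    from accelerate import Accelerator
    from accelerate.utils import AORecipeKwargs  # torchao-based FP8 recipe handler

    # The patch also sets this env var so downstream checks see fp8 mixed precision.
    os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"

    accelerator = Accelerator(
        mixed_precision="fp8",
        kwargs_handlers=[AORecipeKwargs()],  # enables torchao float8 training
    )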