add preliminary fp8 support

This commit is contained in:
Wing Lian
2025-04-06 23:54:50 -04:00
parent 37630fc6ef
commit 1aec93cf9e
5 changed files with 14 additions and 9 deletions

View File

@@ -562,13 +562,16 @@ class AxolotlTrainer(
return res
# NOTE(review): this hunk appears to show removed and added diff lines together,
# with the +/- markers and indentation stripped (the hunk header gives 13 old
# lines and 16 new lines) -- confirm against the raw diff before relying on it.
#
# Builds a dict of extra keyword arguments for fp8 mixed precision and returns
# it; the dict stays empty when the fp8 condition below is not met.
#
# Old signature, presumably removed by this commit:
def override_accelerator_args(self, **kwargs): # pylint: disable=unused-argument
# New signature, presumably added by this commit: same role under a new name,
# plus an explicit `fp8` parameter that defaults to None.
def additional_accelerator_args(
self, fp8=None, **kwargs
): # pylint: disable=unused-argument
# Accumulates the extra keyword arguments that are returned at the end.
ret_kwargs = {}
# Old condition, presumably removed: fp8 was detected from the environment.
if os.environ.get("ACCELERATE_MIXED_PRECISION") == "fp8":
# New condition, presumably added: fp8 is requested through the argument.
if fp8:
# AORecipeKwargs comes from accelerate.utils; a default-constructed instance
# is passed as the only entry in "kwargs_handlers".
from accelerate.utils import AORecipeKwargs
ret_kwargs["mixed_precision"] = "fp8"
ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()]
# Presumably added by this commit: also writes the fp8 setting into the process
# environment, which is a side effect visible outside this method.
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
return ret_kwargs

View File

@@ -45,7 +45,7 @@ def lce_forward(
Returns:
"""
print("=" * 30 + " lce_forward " + "=" * 30)
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None

View File

@@ -17,8 +17,8 @@ ORIGINAL_TRAINER_CODE = """
"""
PATCHED_TRAINER_CODE = """
if hasattr(self, "override_accelerator_args"):
additional_args = self.override_accelerator_args(**args)
if hasattr(self, "additional_accelerator_args"):
additional_args = self.additional_accelerator_args(fp8=True, **args)
if additional_args:
args.update(additional_args)

View File

@@ -558,11 +558,12 @@ class ModelLoader:
plugin_manager.pre_model_load(self.cfg)
# monkey patch to allow additional Accelerator init kwargs
from axolotl.monkeypatch.trainer_accelerator_args import (
patch_create_accelerate_code_for_fp8,
)
if self.cfg.fp8:
from axolotl.monkeypatch.trainer_accelerator_args import (
patch_create_accelerate_code_for_fp8,
)
patch_create_accelerate_code_for_fp8()
patch_create_accelerate_code_for_fp8()
if self.cfg.adapter:
from axolotl.monkeypatch.transformers_fa_utils import (

View File

@@ -169,6 +169,7 @@ class AxolotlInputConfig(
bf16: Literal["auto"] | bool | None = "auto"
fp16: bool | None = None
fp8: bool | None = None
bfloat16: bool | None = None # for non-AMP cases
float16: bool | None = None # for non-AMP cases
tf32: bool | None = None