add preliminary fp8 support
This commit is contained in:
@@ -562,13 +562,16 @@ class AxolotlTrainer(
|
||||
|
||||
return res
|
||||
|
||||
def override_accelerator_args(self, **kwargs): # pylint: disable=unused-argument
|
||||
def additional_accelerator_args(
|
||||
self, fp8=None, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
ret_kwargs = {}
|
||||
if os.environ.get("ACCELERATE_MIXED_PRECISION") == "fp8":
|
||||
if fp8:
|
||||
from accelerate.utils import AORecipeKwargs
|
||||
|
||||
ret_kwargs["mixed_precision"] = "fp8"
|
||||
ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()]
|
||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
|
||||
|
||||
return ret_kwargs
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ def lce_forward(
|
||||
Returns:
|
||||
"""
|
||||
|
||||
print("=" * 30 + " lce_forward " + "=" * 30)
|
||||
# pylint: disable=duplicate-code
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
|
||||
@@ -17,8 +17,8 @@ ORIGINAL_TRAINER_CODE = """
|
||||
"""
|
||||
|
||||
PATCHED_TRAINER_CODE = """
|
||||
if hasattr(self, "override_accelerator_args"):
|
||||
additional_args = self.override_accelerator_args(**args)
|
||||
if hasattr(self, "additional_accelerator_args"):
|
||||
additional_args = self.additional_accelerator_args(fp8=True, **args)
|
||||
if additional_args:
|
||||
args.update(additional_args)
|
||||
|
||||
|
||||
@@ -558,11 +558,12 @@ class ModelLoader:
|
||||
plugin_manager.pre_model_load(self.cfg)
|
||||
|
||||
# monkey patch to allow additional Accelerator init kwargs
|
||||
from axolotl.monkeypatch.trainer_accelerator_args import (
|
||||
patch_create_accelerate_code_for_fp8,
|
||||
)
|
||||
if self.cfg.fp8:
|
||||
from axolotl.monkeypatch.trainer_accelerator_args import (
|
||||
patch_create_accelerate_code_for_fp8,
|
||||
)
|
||||
|
||||
patch_create_accelerate_code_for_fp8()
|
||||
patch_create_accelerate_code_for_fp8()
|
||||
|
||||
if self.cfg.adapter:
|
||||
from axolotl.monkeypatch.transformers_fa_utils import (
|
||||
|
||||
@@ -169,6 +169,7 @@ class AxolotlInputConfig(
|
||||
|
||||
bf16: Literal["auto"] | bool | None = "auto"
|
||||
fp16: bool | None = None
|
||||
fp8: bool | None = None
|
||||
bfloat16: bool | None = None # for non-AMP cases
|
||||
float16: bool | None = None # for non-AMP cases
|
||||
tf32: bool | None = None
|
||||
|
||||
Reference in New Issue
Block a user