add preliminary fp8 support
This commit is contained in:
@@ -562,13 +562,16 @@ class AxolotlTrainer(
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def override_accelerator_args(self, **kwargs): # pylint: disable=unused-argument
|
def additional_accelerator_args(
|
||||||
|
self, fp8=None, **kwargs
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
ret_kwargs = {}
|
ret_kwargs = {}
|
||||||
if os.environ.get("ACCELERATE_MIXED_PRECISION") == "fp8":
|
if fp8:
|
||||||
from accelerate.utils import AORecipeKwargs
|
from accelerate.utils import AORecipeKwargs
|
||||||
|
|
||||||
ret_kwargs["mixed_precision"] = "fp8"
|
ret_kwargs["mixed_precision"] = "fp8"
|
||||||
ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()]
|
ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()]
|
||||||
|
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
|
||||||
|
|
||||||
return ret_kwargs
|
return ret_kwargs
|
||||||
|
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ def lce_forward(
|
|||||||
Returns:
|
Returns:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
print("=" * 30 + " lce_forward " + "=" * 30)
|
# pylint: disable=duplicate-code
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions
|
output_attentions
|
||||||
if output_attentions is not None
|
if output_attentions is not None
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ ORIGINAL_TRAINER_CODE = """
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
PATCHED_TRAINER_CODE = """
|
PATCHED_TRAINER_CODE = """
|
||||||
if hasattr(self, "override_accelerator_args"):
|
if hasattr(self, "additional_accelerator_args"):
|
||||||
additional_args = self.override_accelerator_args(**args)
|
additional_args = self.additional_accelerator_args(fp8=True, **args)
|
||||||
if additional_args:
|
if additional_args:
|
||||||
args.update(additional_args)
|
args.update(additional_args)
|
||||||
|
|
||||||
|
|||||||
@@ -558,11 +558,12 @@ class ModelLoader:
|
|||||||
plugin_manager.pre_model_load(self.cfg)
|
plugin_manager.pre_model_load(self.cfg)
|
||||||
|
|
||||||
# monkey patch to allow additional Accelerator init kwargs
|
# monkey patch to allow additional Accelerator init kwargs
|
||||||
from axolotl.monkeypatch.trainer_accelerator_args import (
|
if self.cfg.fp8:
|
||||||
patch_create_accelerate_code_for_fp8,
|
from axolotl.monkeypatch.trainer_accelerator_args import (
|
||||||
)
|
patch_create_accelerate_code_for_fp8,
|
||||||
|
)
|
||||||
|
|
||||||
patch_create_accelerate_code_for_fp8()
|
patch_create_accelerate_code_for_fp8()
|
||||||
|
|
||||||
if self.cfg.adapter:
|
if self.cfg.adapter:
|
||||||
from axolotl.monkeypatch.transformers_fa_utils import (
|
from axolotl.monkeypatch.transformers_fa_utils import (
|
||||||
|
|||||||
@@ -169,6 +169,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
bf16: Literal["auto"] | bool | None = "auto"
|
bf16: Literal["auto"] | bool | None = "auto"
|
||||||
fp16: bool | None = None
|
fp16: bool | None = None
|
||||||
|
fp8: bool | None = None
|
||||||
bfloat16: bool | None = None # for non-AMP cases
|
bfloat16: bool | None = None # for non-AMP cases
|
||||||
float16: bool | None = None # for non-AMP cases
|
float16: bool | None = None # for non-AMP cases
|
||||||
tf32: bool | None = None
|
tf32: bool | None = None
|
||||||
|
|||||||
Reference in New Issue
Block a user