basic torchao fp8 mixed precision training (#2926)
* debug
* debug
* debug
* revert unneeded change
* add accelerator config to base trainer builder
* add back accumulated_cache_size_limit setting
* lint
* accelerator constructor patch for single-GPU torch fp8
* lint
* re-using existing fp8 code
* lint
* remove accelerate patch now fix in latest release
* fix
* docs
* add fp8 + fsdp2 example
* remove unused config
* update config
* smoke tests
* add validator
* add 2.7.0 guard for fsdp2
* fix
* add config descriptions
* add FSDP doc link
* nit
* set force_recompute_fp8_weight_in_bwd with enable_fsdp_float8_all_gather
* better cfg for smoke tests
* add test for accelerate patching
* update fp8 validator
@@ -7,7 +7,7 @@ from __future__ import annotations
 import os
 from collections import defaultdict
 from functools import partial, wraps
-from typing import Callable, Literal, Optional
+from typing import Any, Callable, Literal, Optional

 import datasets
 import torch
@@ -522,15 +522,25 @@ class AxolotlTrainer(

         return res

-    # pylint: disable=unused-argument
     def additional_accelerator_args(
-        self, fp8=None, **kwargs
-    ):  # pylint: disable=unused-argument
+        self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
+    ) -> dict[str, Any]:
         ret_kwargs = {}
         if fp8:
+            from accelerate.utils import AORecipeKwargs
+            from torchao.float8 import Float8LinearConfig
+
+            # By default, Float8LinearConfig is instantiated using the "tensorwise"
+            # scaling strategy. See more details here:
+            # https://github.com/pytorch/ao/tree/main/torchao/float8.
+            config = Float8LinearConfig(
+                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+                force_recompute_fp8_weight_in_bwd=enable_fsdp_float8_all_gather is True,
+            )
+
             ret_kwargs["mixed_precision"] = "fp8"
-            ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()]
+            ret_kwargs["kwargs_handlers"] = [AORecipeKwargs(config=config)]  # type: ignore
             os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"

         return ret_kwargs

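For context, the kwargs returned here are forwarded to the `Accelerator` constructor by the monkeypatched `create_accelerator_and_postprocess` shown further below. A minimal standalone sketch of the same recipe, assuming an accelerate release that ships `AORecipeKwargs` and a torchao install (not the trainer's exact code path):

# Rough sketch, not the trainer's exact code path: wiring the same FP8 recipe
# directly into an Accelerator. Assumes accelerate provides AORecipeKwargs and
# torchao is installed.
from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs
from torchao.float8 import Float8LinearConfig

config = Float8LinearConfig(enable_fsdp_float8_all_gather=False)
accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[AORecipeKwargs(config=config)],
)
# model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)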
@@ -154,7 +154,9 @@ class PatchManager:
             patch_create_accelerate_code_for_fp8,
         )

-        patch_create_accelerate_code_for_fp8()
+        patch_create_accelerate_code_for_fp8(
+            self.cfg.fp8_enable_fsdp_float8_all_gather
+        )

     def _apply_flash_attention_peft_patches(self):
         """Apply patches for Flash Attention with PEFT."""

@@ -18,7 +18,7 @@ ORIGINAL_TRAINER_CODE = """

 PATCHED_TRAINER_CODE = """
         if hasattr(self, "additional_accelerator_args"):
-            additional_args = self.additional_accelerator_args(fp8=True, **args)
+            additional_args = self.additional_accelerator_args(fp8=True, enable_fsdp_float8_all_gather={enable_fsdp_float8_all_gather}, **args)
             if additional_args:
                 args.update(additional_args)

@@ -38,9 +38,9 @@ def check_create_accelerate_code_is_patchable() -> bool:
     return ORIGINAL_TRAINER_CODE in create_code


-def patch_create_accelerate_code_for_fp8():
+def patch_create_accelerate_code_for_fp8(enable_fsdp_float8_all_gather: bool):
     """
-    monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs
+    Monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs.
     """

     try:
@@ -54,7 +54,10 @@ def patch_create_accelerate_code_for_fp8():
     if ORIGINAL_TRAINER_CODE not in create_code:
         return

-    create_code = create_code.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
+    patched_trainer_code = PATCHED_TRAINER_CODE.format(
+        enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather
+    )
+    create_code = create_code.replace(ORIGINAL_TRAINER_CODE, patched_trainer_code)
     create_code = create_code.replace(
         "def create_accelerator_and_postprocess(",
         "def fixed_create_accelerator_and_postprocess(",

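The patch works by rewriting the source of `create_accelerator_and_postprocess` as text and re-executing it under a new name. A toy sketch of that source-rewrite technique (the names here are illustrative stand-ins, not axolotl or transformers code):

# Toy illustration of the source-rewrite monkeypatch technique; greet() is a
# made-up stand-in for create_accelerator_and_postprocess.
import inspect

def greet():
    return "hello"

source = inspect.getsource(greet)
source = source.replace('return "hello"', 'return "hello, fp8"')
source = source.replace("def greet(", "def patched_greet(")

namespace: dict = {}
exec(source, namespace)  # compile and execute the rewritten source
greet = namespace["patched_greet"]  # rebind the original name to the patched version
print(greet())  # -> hello, fp8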
@@ -343,7 +343,20 @@ class AxolotlInputConfig(
     fp16: bool | None = Field(
         default=None, json_schema_extra={"description": "Use CUDA fp16"}
     )
-    fp8: bool | None = None
+    fp8: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Enable FP8 mixed precision training using TorchAO. Best "
+            "used in combination with torch.compile."
+        },
+    )
+    fp8_enable_fsdp_float8_all_gather: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Enable FSDP float8 all-gather optimization for FP8 training. Can "
+            "improve training speed by 10-15% when FSDP is enabled."
+        },
+    )
     bfloat16: bool | None = Field(
         default=None,
         json_schema_extra={
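The `json_schema_extra` descriptions surface in the model's generated JSON schema, which is where the config documentation is drawn from. A small standalone sketch of the pattern, using a made-up model name:

# Standalone sketch of the Field(json_schema_extra=...) pattern; PrecisionConfig
# is a made-up model, not part of axolotl's schema.
from pydantic import BaseModel, Field

class PrecisionConfig(BaseModel):
    fp8: bool | None = Field(
        default=None,
        json_schema_extra={"description": "Enable FP8 mixed precision training using TorchAO."},
    )

print(PrecisionConfig.model_json_schema()["properties"]["fp8"])
# The printed property schema includes the "description" entry alongside the inferred type info.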
@@ -360,6 +360,36 @@ class TrainingValidationMixin:
         # RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
         return self

+    @model_validator(mode="before")
+    @classmethod
+    def check_fp8_config(cls, data):
+        if data.get("fp8") and not data.get("torch_compile"):
+            LOG.warning(
+                "torch_compile is strongly recommended for FP8 training in order to "
+                "see speed improvements. Please consider setting `torch_compile: "
+                "true` in your config."
+            )
+        if data.get("fp8") and (
+            data.get("fsdp_config", {}).get("activation_checkpointing", False) is True
+            or data.get("fsdp_config", {}).get("fsdp_activation_checkpointing", False)
+            is True
+        ):
+            LOG.warning(
+                "FP8 + FSDP2 + activation checkpointing may be slower than BF16 "
+                "training. Please consider setting `activation_checkpointing: false` "
+                "in your FSDP config."
+            )
+        if (
+            data.get("fp8_enable_fsdp_float8_all_gather")
+            and not data.get("fsdp_version", None) == 2
+        ):
+            raise ValueError(
+                "fp8_enable_fsdp_float8_all_gather requires FSDP2 (fsdp_version: 2) "
+                "to be used."
+            )
+
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_use_reentrant_mismatch(cls, data):
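As an illustration of the new validator's behavior, a few hypothetical config fragments and the outcome `check_fp8_config` would produce for them:

# Hypothetical config fragments, mirroring the checks in check_fp8_config above.
warns_missing_compile = {"fp8": True}  # accepted, but warns that torch_compile is recommended
warns_act_ckpt = {"fp8": True, "torch_compile": True,
                  "fsdp_config": {"activation_checkpointing": True}}  # accepted, warns it may be slower than BF16
rejected = {"fp8": True, "fp8_enable_fsdp_float8_all_gather": True,
            "fsdp_version": 1}  # raises ValueError: requires FSDP2
ok = {"fp8": True, "torch_compile": True, "fsdp_version": 2,
      "fp8_enable_fsdp_float8_all_gather": True}  # passes with no warnings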