basic torchao fp8 mixed precision training (#2926)

* debug

* debug

* debug

* revert unneeded change

* add accelerator config to base trainer builder

* add back accumulated_cache_size_limit setting

* lint

* accelerator constructor patch for single-GPU torch fp8

* lint

* re-using existing fp8 code

* lint

* remove accelerate patch now fixed in latest release

* fix

* docs

* add fp8 + fsdp2 example

* remove unused config

* update config

* smoke tests

* add validator

* add 2.7.0 guard for fsdp2

* fix

* add config descriptions

* add FSDP doc link

* nit

* set force_recompute_fp8_weight_in_bwd with enable_fsdp_float8_all_gather

* better cfg for smoke tests

* add test for accelerate patching

* update fp8 validator
Author: Dan Saunders
Date:   2025-07-22 16:27:47 -04:00 (committed by GitHub)
Parent: b86a1d47b0
Commit: 208fb7b8e7
11 changed files with 503 additions and 10 deletions
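At a high level, this commit routes axolotl's `fp8` option through accelerate's torchao recipe. For orientation, here is a rough sketch of the torchao float8 training flow it builds on; the API names come from the torchao float8 README linked in the diff below, while the model, layer sizes, and compile call are purely illustrative and not code from this commit:

import torch
from torchao.float8 import convert_to_float8_training

# Swap eligible nn.Linear layers for Float8Linear ("tensorwise" scaling by default).
# Requires a GPU with FP8 support (compute capability 8.9+, e.g. H100 / Ada).
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.Linear(4096, 1024)
).cuda().to(torch.bfloat16)
convert_to_float8_training(model)

# torch.compile is strongly recommended; without it the fp8 kernels rarely pay off.
model = torch.compile(model)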

View File

@@ -7,7 +7,7 @@ from __future__ import annotations
 import os
 from collections import defaultdict
 from functools import partial, wraps
-from typing import Callable, Literal, Optional
+from typing import Any, Callable, Literal, Optional

 import datasets
 import torch
@@ -522,15 +522,25 @@ class AxolotlTrainer(
         return res

     # pylint: disable=unused-argument
     def additional_accelerator_args(
-        self, fp8=None, **kwargs
-    ):  # pylint: disable=unused-argument
+        self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
+    ) -> dict[str, Any]:
         ret_kwargs = {}

         if fp8:
             from accelerate.utils import AORecipeKwargs
+            from torchao.float8 import Float8LinearConfig
+
+            # By default, Float8LinearConfig is instantiated using the "tensorwise"
+            # scaling strategy. See more details here:
+            # https://github.com/pytorch/ao/tree/main/torchao/float8.
+            config = Float8LinearConfig(
+                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+                force_recompute_fp8_weight_in_bwd=enable_fsdp_float8_all_gather is True,
+            )
+
             ret_kwargs["mixed_precision"] = "fp8"
-            ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()]
+            ret_kwargs["kwargs_handlers"] = [AORecipeKwargs(config=config)]  # type: ignore
             os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"

         return ret_kwargs
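The kwargs returned here are later splatted into the Accelerator constructor by the patched create_accelerator_and_postprocess (see the patch further below). Assembled by hand, the equivalent setup is roughly the following; this is a sketch based on the diff above, assuming a recent accelerate release that ships the torchao recipe:

from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs
from torchao.float8 import Float8LinearConfig

config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=True,      # only useful under FSDP2
    force_recompute_fp8_weight_in_bwd=True,  # paired with the all-gather flag above
)
accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[AORecipeKwargs(config=config)],
)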

View File

@@ -154,7 +154,9 @@ class PatchManager:
             patch_create_accelerate_code_for_fp8,
         )

-        patch_create_accelerate_code_for_fp8()
+        patch_create_accelerate_code_for_fp8(
+            self.cfg.fp8_enable_fsdp_float8_all_gather
+        )

     def _apply_flash_attention_peft_patches(self):
         """Apply patches for Flash Attention with PEFT."""

View File

@@ -18,7 +18,7 @@ ORIGINAL_TRAINER_CODE = """
PATCHED_TRAINER_CODE = """
        if hasattr(self, "additional_accelerator_args"):
-            additional_args = self.additional_accelerator_args(fp8=True, **args)
+            additional_args = self.additional_accelerator_args(fp8=True, enable_fsdp_float8_all_gather={enable_fsdp_float8_all_gather}, **args)

            if additional_args:
                args.update(additional_args)
@@ -38,9 +38,9 @@ def check_create_accelerate_code_is_patchable() -> bool:
     return ORIGINAL_TRAINER_CODE in create_code


-def patch_create_accelerate_code_for_fp8():
+def patch_create_accelerate_code_for_fp8(enable_fsdp_float8_all_gather: bool):
     """
-    monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs
+    Monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs.
     """
     try:
@@ -54,7 +54,10 @@ def patch_create_accelerate_code_for_fp8():
     if ORIGINAL_TRAINER_CODE not in create_code:
         return

-    create_code = create_code.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
+    patched_trainer_code = PATCHED_TRAINER_CODE.format(
+        enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather
+    )
+    create_code = create_code.replace(ORIGINAL_TRAINER_CODE, patched_trainer_code)
     create_code = create_code.replace(
         "def create_accelerator_and_postprocess(",
         "def fixed_create_accelerator_and_postprocess(",

View File

@@ -343,7 +343,20 @@ class AxolotlInputConfig(
     fp16: bool | None = Field(
         default=None, json_schema_extra={"description": "Use CUDA fp16"}
     )
-    fp8: bool | None = None
+    fp8: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Enable FP8 mixed precision training using TorchAO. Best "
+            "used in combination with torch.compile."
+        },
+    )
+    fp8_enable_fsdp_float8_all_gather: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Enable FSDP float8 all-gather optimization for FP8 training. Can "
+            "improve training speed by 10-15% when FSDP is enabled."
+        },
+    )
     bfloat16: bool | None = Field(
         default=None,
         json_schema_extra={
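Because the new descriptions live in json_schema_extra, they surface in the model's generated JSON schema. A quick hedged check; the import path is assumed, not taken from this diff:

from axolotl.utils.schemas.config import AxolotlInputConfig  # path assumed

schema = AxolotlInputConfig.model_json_schema()
print(schema["properties"]["fp8"]["description"])
print(schema["properties"]["fp8_enable_fsdp_float8_all_gather"]["description"])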

View File

@@ -360,6 +360,36 @@ class TrainingValidationMixin:
         # RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
         return self

+    @model_validator(mode="before")
+    @classmethod
+    def check_fp8_config(cls, data):
+        if data.get("fp8") and not data.get("torch_compile"):
+            LOG.warning(
+                "torch_compile is strongly recommended for FP8 training in order to "
+                "see speed improvements. Please consider setting `torch_compile: "
+                "true` in your config."
+            )
+        if data.get("fp8") and (
+            data.get("fsdp_config", {}).get("activation_checkpointing", False) is True
+            or data.get("fsdp_config", {}).get("fsdp_activation_checkpointing", False)
+            is True
+        ):
+            LOG.warning(
+                "FP8 + FSDP2 + activation checkpointing may be slower than BF16 "
+                "training. Please consider setting `activation_checkpointing: false` "
+                "in your FSDP config."
+            )
+        if (
+            data.get("fp8_enable_fsdp_float8_all_gather")
+            and not data.get("fsdp_version", None) == 2
+        ):
+            raise ValueError(
+                "fp8_enable_fsdp_float8_all_gather requires FSDP2 (fsdp_version: 2) "
+                "to be used."
+            )
+
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_use_reentrant_mismatch(cls, data):
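Putting the validator's three checks together, a config that enables FP8 without tripping any of them combines roughly these keys. Shown here as a plain dict mirroring the YAML config; values are illustrative:

fp8_cfg = {
    "fp8": True,                                # TorchAO FP8 mixed precision
    "torch_compile": True,                      # avoids the torch_compile warning
    "fsdp_version": 2,                          # required by the all-gather check
    "fp8_enable_fsdp_float8_all_gather": True,
    "fsdp_config": {
        "activation_checkpointing": False,      # avoids the FP8 + FSDP2 perf warning
    },
}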