basic torchao fp8 mixed precision training (#2926)

* debug

* debug

* debug

* revert unneeded change

* add accelerator config to base trainer builder

* add back accumulated_cache_size_limit setting

* lint

* accelerator constructor patch for single-GPU torch fp8 (see the sketch after this list)

* lint

* re-using existing fp8 code

* lint

* remove accelerate patch, now fixed in latest release

* fix

* docs

* add fp8 + fsdp2 example

* remove unused config

* update config

* smoke tests

* add validator

* add torch 2.7.0 guard for fsdp2 (sketch below)

* fix

* add config descriptions

* add FSDP doc link

* nit

* set force_recompute_fp8_weight_in_bwd with enable_fsdp_float8_all_gather (sketch below)

* better cfg for smoke tests

* add test for accelerate patching

* update fp8 validator
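
The "accelerator constructor patch" above boils down to asking accelerate for fp8 mixed precision backed by torchao. A minimal single-GPU sketch, assuming a recent accelerate release that ships AORecipeKwargs (per the "remove accelerate patch" item, the local workaround became unnecessary once this landed upstream):

import torch.nn as nn
from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs

# fp8 mixed precision with the torchao backend; AORecipeKwargs() without
# arguments falls back to torchao's default float8 settings.
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
# With the AO recipe, prepare() converts eligible nn.Linear layers for float8 training.
model = accelerator.prepare(model)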
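
The torch 2.7.0 guard is the usual fail-fast version check: FSDP2 together with fp8 needs a new enough torch, so the validator should raise at config time rather than deep inside FSDP setup. A sketch of that kind of guard; the helper name and error message are illustrative, not axolotl's exact code:

import torch
from packaging import version

def check_fp8_fsdp2_torch_version() -> None:
    # Hypothetical validator hook: reject fp8 + FSDP2 on torch < 2.7.0.
    if version.parse(torch.__version__) < version.parse("2.7.0"):
        raise ValueError(
            f"fp8 mixed precision with FSDP2 requires torch >= 2.7.0, found {torch.__version__}"
        )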
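
For the all-gather item: in torchao's float8 training config, enable_fsdp_float8_all_gather keeps FSDP2's weight all-gather in float8, and force_recompute_fp8_weight_in_bwd recomputes the casted weight during backward instead of saving it, which is why the two are set together. A sketch using torchao's public float8 API; how axolotl wires this into the trainer is assumed, not shown:

import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=True,      # all-gather FSDP2 weight shards in float8
    force_recompute_fp8_weight_in_bwd=True,  # recompute the fp8-cast weight in backward rather than storing it
)

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
model = convert_to_float8_training(model, config=config)  # swaps nn.Linear modules for their float8 equivalents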
Dan Saunders
2025-07-22 16:27:47 -04:00
committed by GitHub
parent b86a1d47b0
commit 208fb7b8e7
11 changed files with 503 additions and 10 deletions


@@ -0,0 +1,26 @@
"""
Unit tests for trainer accelerator args monkeypatch
"""
import unittest
from axolotl.monkeypatch.trainer_accelerator_args import (
check_create_accelerate_code_is_patchable,
)
class TestTrainerAcceleratorArgs(unittest.TestCase):
"""
Unit test class for trainer accelerator args monkeypatch
"""
def test_check_create_accelerate_code_is_patchable(self):
"""
Test that the upstream transformers code is still patchable.
This will fail if the patched code changes upstream.
"""
assert check_create_accelerate_code_is_patchable()
if __name__ == "__main__":
unittest.main()
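
The test above guards axolotl's monkeypatch of transformers' Trainer.create_accelerator_and_postprocess: before patching, the code verifies that the upstream source still looks like what the patch was written against. A minimal sketch of that pattern; the marker string is an assumption for illustration, not the helper's actual implementation:

import inspect

from transformers import Trainer

# Assumed marker: a line the monkeypatch expects to find in the upstream method.
EXPECTED_SNIPPET = "self.accelerator = Accelerator("

def check_create_accelerate_code_is_patchable() -> bool:
    source = inspect.getsource(Trainer.create_accelerator_and_postprocess)
    return EXPECTED_SNIPPET in source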