basic torchao fp8 mixed precision training (#2926)
* debug * debug * debug * revert unneeded change * add accelerator config to base trainer builder * add back accumulated_cache_size_limit setting * lint * accelerator constructor patch for single-GPU torch fp8 * lint * re-using existing fp8 code * lint * remove accelerate patch now fix in latest release * fix * docs * add fp8 + fsdp2 example * remove unused config * update config * smoke tests * add validator * add 2.7.0 guard for fsdp2 * fix * add config descriptions * add FSDP doc link * nit * set force_recompute_fp8_weight_in_bwd with enable_fsdp_float8_all_gather * better cfg for smoke tests * add test for accelerate patching * update fp8 validator
This commit is contained in:
26
tests/monkeypatch/test_trainer_accelerator_args.py
Normal file
26
tests/monkeypatch/test_trainer_accelerator_args.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Unit tests for trainer accelerator args monkeypatch
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.monkeypatch.trainer_accelerator_args import (
|
||||
check_create_accelerate_code_is_patchable,
|
||||
)
|
||||
|
||||
|
||||
class TestTrainerAcceleratorArgs(unittest.TestCase):
|
||||
"""
|
||||
Unit test class for trainer accelerator args monkeypatch
|
||||
"""
|
||||
|
||||
def test_check_create_accelerate_code_is_patchable(self):
|
||||
"""
|
||||
Test that the upstream transformers code is still patchable.
|
||||
This will fail if the patched code changes upstream.
|
||||
"""
|
||||
assert check_create_accelerate_code_is_patchable()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user