Files
axolotl/tests/monkeypatch/test_trainer_accelerator_args.py
Dan Saunders 208fb7b8e7 basic torchao fp8 mixed precision training (#2926)
* debug

* debug

* debug

* revert unneeded change

* add accelerator config to base trainer builder

* add back accumulated_cache_size_limit setting

* lint

* accelerator constructor patch for single-GPU torch fp8

* lint

* re-using existing fp8 code

* lint

* remove accelerate patch now that the fix is in the latest release

* fix

* docs

* add fp8 + fsdp2 example

* remove unused config

* update config

* smoke tests

* add validator

* add 2.7.0 guard for fsdp2

* fix

* add config descriptions

* add FSDP doc link

* nit

* set force_recompute_fp8_weight_in_bwd with enable_fsdp_float8_all_gather

* better cfg for smoke tests

* add test for accelerate patching

* update fp8 validator
2025-07-22 16:27:47 -04:00

27 lines
643 B
Python

"""
Unit tests for trainer accelerator args monkeypatch
"""
import unittest
from axolotl.monkeypatch.trainer_accelerator_args import (
check_create_accelerate_code_is_patchable,
)
class TestTrainerAcceleratorArgs(unittest.TestCase):
    """
    Unit test class for trainer accelerator args monkeypatch
    """

    def test_check_create_accelerate_code_is_patchable(self):
        """
        Test that the upstream transformers code is still patchable.

        This will fail if the patched code changes upstream.
        """
        # Use the TestCase assertion instead of a bare ``assert``: bare asserts
        # are removed when Python runs with -O, which would silently turn this
        # check into a no-op. assertTrue also gives a descriptive failure.
        self.assertTrue(
            check_create_accelerate_code_is_patchable(),
            "check_create_accelerate_code_is_patchable() returned a falsy value; "
            "the upstream code targeted by the trainer accelerator args "
            "monkeypatch may have changed.",
        )
if __name__ == "__main__":
    # Executing this file directly (rather than via a test collector)
    # hands control to unittest's command-line runner.
    unittest.main()