* debug
* debug
* debug
* revert unneeded change
* add accelerator config to base trainer builder
* add back accumulated_cache_size_limit setting
* lint
* accelerator constructor patch for single-GPU torch fp8
* lint
* re-use existing fp8 code
* lint
* remove accelerate patch, now fixed in latest release
* fix
* docs
* add fp8 + fsdp2 example
* remove unused config
* update config
* smoke tests
* add validator
* add 2.7.0 guard for fsdp2
* fix
* add config descriptions
* add FSDP doc link
* nit
* set force_recompute_fp8_weight_in_bwd with enable_fsdp_float8_all_gather
* better cfg for smoke tests
* add test for accelerate patching
* update fp8 validator
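For context on the last few commits: `force_recompute_fp8_weight_in_bwd` and `enable_fsdp_float8_all_gather` are fields on torchao's `Float8LinearConfig`. Below is a minimal sketch of how those two flags pair up in torchao's float8 training API; it illustrates the setting combination only and is not axolotl's actual trainer-builder wiring.

import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# With FSDP float8 all-gather enabled, recomputing the fp8 weight cast in the
# backward pass avoids keeping both fp8 and high-precision weight copies alive.
fp8_config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=True,
    force_recompute_fp8_weight_in_bwd=True,
)

# Toy model; convert_to_float8_training swaps eligible nn.Linear modules
# for their float8 training equivalents in place.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
convert_to_float8_training(model, config=fp8_config)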
"""
|
|
Unit tests for trainer accelerator args monkeypatch
|
|
"""
|
|
|
|
import unittest
|
|
|
|
from axolotl.monkeypatch.trainer_accelerator_args import (
|
|
check_create_accelerate_code_is_patchable,
|
|
)
|
|
|
|
|
|
class TestTrainerAcceleratorArgs(unittest.TestCase):
|
|
"""
|
|
Unit test class for trainer accelerator args monkeypatch
|
|
"""
|
|
|
|
def test_check_create_accelerate_code_is_patchable(self):
|
|
"""
|
|
Test that the upstream transformers code is still patchable.
|
|
This will fail if the patched code changes upstream.
|
|
"""
|
|
assert check_create_accelerate_code_is_patchable()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|
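For readers unfamiliar with this pattern: a "patchability" check like `check_create_accelerate_code_is_patchable` typically inspects the current upstream source and verifies that the exact code the monkeypatch rewrites is still present, so the test suite fails loudly when transformers changes underneath the patch. The sketch below shows the general shape under that assumption; `EXPECTED_SNIPPET` and the helper name are hypothetical, not axolotl's implementation.

import inspect

from transformers import Trainer

# Hypothetical fragment the patch expects to find in the upstream method;
# the real check compares against the code axolotl actually rewrites.
EXPECTED_SNIPPET = "self.accelerator = Accelerator("


def sketch_check_is_patchable() -> bool:
    """Return True if the upstream method still contains the expected code."""
    source = inspect.getsource(Trainer.create_accelerator_and_postprocess)
    return EXPECTED_SNIPPET in source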