Files
axolotl/tests/e2e/patched/test_fsdp2_qlora.py
Dan Saunders 79ddaebe9a Add ruff, remove black, isort, flake8, pylint (#3092)
* black, isort, flake8 -> ruff

* remove unused

* add back needed import

* fix
2025-08-23 23:37:33 -04:00

31 lines
1004 B
Python

"""Integration tests for FSDP2 Params4bit patches."""
import pytest
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
class TestFSDPPatchIntegration:
    """Test FSDP patch integration."""

    @pytest.mark.integration
    def test_fsdp2_init_patches(self):
        """Test that all patches can be applied together.

        Applies both ``FSDPParam`` patches from
        ``axolotl.monkeypatch.fsdp2_qlora`` and asserts that each targeted
        method on the class was replaced. The patches rebind attributes on the
        ``FSDPParam`` class itself (process-global state). The original methods
        are therefore restored in a ``finally`` clause, so other tests in the
        same process run against unpatched code and this test can be re-run.
        """
        from axolotl.monkeypatch.fsdp2_qlora import (
            apply_init_sharded_param_patch,
            apply_init_unsharded_param_patch,
        )

        # Capture the originals before patching: they are the baseline for the
        # "was it patched?" checks and the values restored afterwards.
        # NOTE(review): assumes both are plain instance methods, so class
        # attribute access returns the underlying function and re-assigning it
        # restores the exact original; confirm if torch changes their kind.
        original_init_sharded = FSDPParam._init_sharded_param
        original_init_unsharded = FSDPParam.init_unsharded_param

        try:
            # Apply patches
            apply_init_sharded_param_patch()
            apply_init_unsharded_param_patch()

            # A successful patch rebinds the class attribute to a different
            # function object, so compare by identity.
            assert FSDPParam._init_sharded_param is not original_init_sharded, (
                "_init_sharded_param was not patched"
            )
            assert FSDPParam.init_unsharded_param is not original_init_unsharded, (
                "init_unsharded_param was not patched"
            )
        finally:
            # Undo the global class mutation even if patching or an assertion
            # failed, so later tests are unaffected.
            FSDPParam._init_sharded_param = original_init_sharded
            FSDPParam.init_unsharded_param = original_init_unsharded