# * upgrade transformers==4.55.1
# * also upgrade bnb
# * remove bnb Params4bit patch (upstreamed)
# * use latest causal-conv1d
# * fix patching ring-flash-attn with now-missing imports
#
# Co-authored-by: Dan Saunders <danjsaund@gmail.com>
#
# 34 lines · 1.1 KiB · Python
"""Integration tests for FSDP2 Params4bit patches."""
|
|
|
|
import pytest
|
|
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
|
|
|
|
|
class TestFSDPPatchIntegration:
    """Test FSDP patch integration."""

    @pytest.mark.integration
    def test_fsdp2_init_patches(self):
        """Test that all patches can be applied together.

        Both ``apply_*`` helpers are expected to rebind methods on the
        ``FSDPParam`` class itself (that is exactly what the assertions below
        check). Because such class-level patches are process-global, the
        original methods are put back in a ``finally`` block so this test
        does not leak patched ``FSDPParam`` methods into other tests.
        """
        from axolotl.monkeypatch.fsdp2_qlora import (
            apply_init_sharded_param_patch,
            apply_init_unsharded_param_patch,
        )

        # pylint: disable=protected-access
        original_init_sharded = FSDPParam._init_sharded_param
        original_init_unsharded = FSDPParam.init_unsharded_param

        try:
            # Apply patches
            apply_init_sharded_param_patch()
            apply_init_unsharded_param_patch()

            # Function objects compare by identity, so a successful patch
            # must have bound a different object to each attribute.
            assert (
                FSDPParam._init_sharded_param is not original_init_sharded
            ), "_init_sharded_param was not patched"
            assert (
                FSDPParam.init_unsharded_param is not original_init_unsharded
            ), "init_unsharded_param was not patched"
        finally:
            # Restore the unpatched methods whether or not the assertions
            # passed, so later tests see a pristine FSDPParam.
            FSDPParam._init_sharded_param = original_init_sharded
            FSDPParam.init_unsharded_param = original_init_unsharded