Files
axolotl/tests/e2e/patched/test_fsdp2_qlora.py
Dan Saunders e758343cac FSDP2 + LoRA kernels (#2992)
* impl fix

* smoke tests

* patches for fsdp2 + qlora compat

* nit

* working fix

* working fix

* fix merge

* minifying patches; update bnb dep

* renaming; adding tests

* remove duplicate test, add dora guard

* generalize __torch_function__

* revert generalization

* update comments
2025-08-03 20:05:17 -04:00

132 lines
4.7 KiB
Python

"""Integration tests for FSDP Params4bit patches."""
from unittest.mock import Mock, patch
import bitsandbytes as bnb
import pytest
import torch
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
from axolotl.monkeypatch.fsdp2_qlora import (
apply_bnb_torch_function_patch,
patched_torch_function,
)
@pytest.fixture
def mock_params4bit():
    """Return a ``Mock`` standing in for a Params4bit tensor.

    Every quantization-related attribute is given a fixed, recognizable value so
    tests can check that the attributes are carried over onto derived objects.
    """
    attrs = {
        "requires_grad": True,
        "quant_state": "test_state",
        "blocksize": 128,
        "compress_statistics": True,
        "quant_type": "fp4",
        "quant_storage": "test_storage",
        "module": "test_module",
        "bnb_quantized": True,
    }
    # Mock(**kwargs) forwards to configure_mock, which assigns each attribute.
    return Mock(**attrs)
class TestBnbTorchFunctionPatch:
    """Test the Params4bit.__torch_function__ patch."""

    # Params4bit constructor kwargs that the patch must copy from the source
    # tensor onto every piece produced by torch.chunk / torch.split. These are
    # exactly the attributes the ``mock_params4bit`` fixture sets.
    PRESERVED_ATTRS = (
        "requires_grad",
        "quant_state",
        "blocksize",
        "compress_statistics",
        "quant_type",
        "quant_storage",
        "module",
        "bnb_quantized",
    )

    def test_apply_patch(self):
        """Test that the patch can be applied."""
        with patch("bitsandbytes.nn.modules.Params4bit") as mock_cls:
            apply_bnb_torch_function_patch()
            assert hasattr(mock_cls, "__torch_function__")
            assert isinstance(mock_cls.__torch_function__, classmethod)
            # The installed classmethod must wrap the function tested below;
            # otherwise the remaining tests would exercise code that never runs.
            assert mock_cls.__torch_function__.__func__ is patched_torch_function

    def _assert_split_preserves_attributes(self, func, params4bit):
        """Check that a splitting ``func`` re-wraps each piece via ``cls``.

        ``torch.nn.Parameter.__torch_function__`` is mocked to return two fixed
        chunks, so only the patch's re-wrapping logic is exercised.

        Args:
            func: The torch function being dispatched (``torch.chunk`` or
                ``torch.split``).
            params4bit: The mock Params4bit whose attributes must be copied.
        """
        mock_cls = Mock()
        chunks = (torch.tensor([1, 2]), torch.tensor([3, 4]))
        with patch("torch.nn.Parameter.__torch_function__", return_value=chunks):
            result = patched_torch_function(
                mock_cls,
                func,
                (type(params4bit),),
                args=(params4bit, 2),
            )

        assert isinstance(result, tuple)
        assert len(result) == len(chunks)
        # Each element must be the object built by cls(...), not a raw chunk.
        assert all(item is mock_cls.return_value for item in result)
        # Check that Params4bit constructor was called once per chunk with every
        # attribute of the source tensor preserved.
        assert mock_cls.call_count == len(chunks)
        for call in mock_cls.call_args_list:
            kwargs = call[1]
            for attr in self.PRESERVED_ATTRS:
                assert kwargs[attr] == getattr(params4bit, attr), attr

    # pylint: disable=redefined-outer-name
    def test_torch_chunk_preserves_attributes(self, mock_params4bit):
        """Test that torch.chunk preserves Params4bit attributes."""
        self._assert_split_preserves_attributes(torch.chunk, mock_params4bit)

    # pylint: disable=redefined-outer-name
    def test_torch_split_preserves_attributes(self, mock_params4bit):
        """Test that torch.split preserves Params4bit attributes."""
        self._assert_split_preserves_attributes(torch.split, mock_params4bit)

    # pylint: disable=redefined-outer-name
    def test_other_functions_fallback(self, mock_params4bit):
        """Test that non-chunk/split functions use Parameter fallback."""
        mock_cls = Mock()
        fallback_result = torch.tensor([5, 6, 7])
        with patch(
            "torch.nn.Parameter.__torch_function__", return_value=fallback_result
        ) as mock_fallback:
            result = patched_torch_function(
                mock_cls, torch.add, (type(mock_params4bit),), args=(mock_params4bit, 1)
            )
            # Should call Parameter.__torch_function__ and return its result
            mock_fallback.assert_called_once()
        assert result is fallback_result
        # No Params4bit should be constructed on the fallback path.
        mock_cls.assert_not_called()
class TestFSDPPatchIntegration:
    """Test FSDP patch integration."""

    # Sentinel meaning "attribute not present in the class's own __dict__".
    _MISSING = object()

    @pytest.mark.integration
    def test_all_patches_together(self):
        """Test that all patches can be applied together.

        The ``apply_*`` helpers overwrite attributes on ``Params4bit`` and
        ``FSDPParam`` at class level. The original class-dict entries are
        snapshotted first and put back in a ``finally`` block, so the patched
        methods do not leak into tests that run later in the same session.
        """
        from axolotl.monkeypatch.fsdp2_qlora import (
            apply_init_sharded_param_patch,
            apply_init_unsharded_param_patch,
        )

        missing = self._MISSING
        params4bit_cls = bnb.nn.modules.Params4bit
        torch_fn_key = (params4bit_cls, "__torch_function__")
        sharded_key = (FSDPParam, "_init_sharded_param")
        unsharded_key = (FSDPParam, "init_unsharded_param")

        def own_attr(key):
            """Return the raw entry for ``key`` from the class's own __dict__."""
            cls, name = key
            return vars(cls).get(name, missing)

        # Store original methods before patching. Raw class-dict entries are
        # used instead of getattr(): getattr() on a classmethod yields a fresh
        # bound method, and an attribute that is only inherited (e.g. from
        # torch.nn.Parameter) must be deleted on restore, not copied down.
        originals = {
            key: own_attr(key) for key in (torch_fn_key, sharded_key, unsharded_key)
        }

        try:
            # Apply patches
            apply_bnb_torch_function_patch()
            apply_init_sharded_param_patch()
            apply_init_unsharded_param_patch()

            # Verify patches were applied
            if originals[torch_fn_key] is missing:
                assert (
                    own_attr(torch_fn_key) is not missing
                ), "Params4bit.__torch_function__ was not added"
            else:
                assert (
                    own_attr(torch_fn_key) is not originals[torch_fn_key]
                ), "Params4bit.__torch_function__ was not patched"

            # Check that FSDP methods were patched
            assert (
                own_attr(sharded_key) is not originals[sharded_key]
            ), "_init_sharded_param was not patched"
            assert (
                own_attr(unsharded_key) is not originals[unsharded_key]
            ), "init_unsharded_param was not patched"
        finally:
            # Undo the class-level monkeypatches so other tests see pristine
            # Params4bit / FSDPParam classes.
            for (cls, name), original in originals.items():
                if original is missing:
                    if name in vars(cls):
                        delattr(cls, name)
                else:
                    setattr(cls, name, original)