diff --git a/requirements.txt b/requirements.txt index 2d72f307a..4e82dfd89 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ # START section of dependencies that don't install on Darwin/MacOS -bitsandbytes==0.46.0 +bitsandbytes==0.46.1 triton>=3.0.0 mamba-ssm==1.2.0.post1 xformers>=0.0.23.post1 diff --git a/src/axolotl/kernels/lora.py b/src/axolotl/kernels/lora.py index 63c9e57bd..82ec91107 100644 --- a/src/axolotl/kernels/lora.py +++ b/src/axolotl/kernels/lora.py @@ -14,6 +14,7 @@ from typing import Callable import torch from bitsandbytes.functional import QuantState from torch import nn +from torch.distributed.tensor import DTensor from .geglu import geglu_backward, geglu_forward from .quantize import dequantize @@ -54,8 +55,21 @@ def get_lora_parameters( if hasattr(proj, "active_adapters") else proj.active_adapter ) - A = proj.lora_A[active_adapter].weight - B = proj.lora_B[active_adapter].weight + + linear_A = proj.lora_A[active_adapter] + linear_B = proj.lora_B[active_adapter] + + # This manual unsharding is needed for FSDP2 + LoRA kernels compatibility. + # We fuse linear layers + LoRA adapters calculations into a single + # torch.autograd.Function, bypassing the registered unshard / reshard behavior. + # Note that we don't apply resharding later in this module (it gets messy quickly), + # but LoRA parameters are generally small enough that this is not an issue. + if isinstance(linear_A.weight, DTensor): + linear_A.unshard() + linear_B.unshard() + + A = linear_A.weight + B = linear_B.weight s = proj.scaling[active_adapter] quant_state = getattr(W, "quant_state", None) @@ -102,8 +116,8 @@ def matmul_lora( del W if A is not None: - A, B = A.t(), B.t() - out += (X @ A.to(dtype)) @ (s * B.to(dtype)) + A, B = A.t().to(dtype), B.t().to(dtype) + out += s * X @ A @ B return out.view(batch, seq_len, -1) if reshape else out diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 9eb779113..e16f03649 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -65,6 +65,7 @@ class PatchManager: self._patch_llama_derived_model() self._apply_mistral_cross_entropy_patch() self._apply_self_attention_lora_patch() + self._apply_fsdp2_bnb_patches() def apply_post_plugin_pre_model_load_patches(self): """Apply post plugin-pre_model_load load patches based on config.""" @@ -260,6 +261,23 @@ class PatchManager: has_remote_code=has_remote_code, ) + def _apply_fsdp2_bnb_patches(self): + """Apply FSDP2 BNB patches.""" + if ( + self.cfg.fsdp_config + and str(self.cfg.fsdp_version) == "2" + and self.cfg.adapter == "qlora" + ): + from axolotl.monkeypatch.fsdp2_qlora import ( + apply_bnb_torch_function_patch, + apply_init_sharded_param_patch, + apply_init_unsharded_param_patch, + ) + + apply_bnb_torch_function_patch() + apply_init_sharded_param_patch() + apply_init_unsharded_param_patch() + def _apply_tiled_mlp(self, model_type: str): if self.cfg.tiled_mlp: from axolotl.monkeypatch.tiled_mlp import ( diff --git a/src/axolotl/monkeypatch/fsdp2_qlora.py b/src/axolotl/monkeypatch/fsdp2_qlora.py new file mode 100644 index 000000000..a2cb7e472 --- /dev/null +++ b/src/axolotl/monkeypatch/fsdp2_qlora.py @@ -0,0 +1,205 @@ +""" +Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as +our LoRA / QLoRA Triton kernels to work with FSDP2. 
+
+This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
+Params4bit parameters.
+"""
+
+import importlib
+import inspect
+
+import torch
+from torch.nn import Parameter
+
+from axolotl.monkeypatch.utils import detab_code
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+
+def patched_torch_function(cls, func, types, args=(), kwargs=None):
+    """
+    Patched version of Params4bit.__torch_function__ for preserving Params4bit
+    class identity and attributes.
+    """
+    if kwargs is None:
+        kwargs = {}
+
+    if func in [torch.chunk, torch.split]:
+        tensor = args[0]
+        result = Parameter.__torch_function__(func, types, args, kwargs)
+
+        if isinstance(result, tuple):
+            return tuple(
+                cls(
+                    data=chunk,
+                    requires_grad=tensor.requires_grad,
+                    quant_state=tensor.quant_state,
+                    blocksize=tensor.blocksize,
+                    compress_statistics=tensor.compress_statistics,
+                    quant_type=tensor.quant_type,
+                    quant_storage=tensor.quant_storage,
+                    module=tensor.module,
+                    bnb_quantized=tensor.bnb_quantized,
+                )
+                for chunk in result
+            )
+
+        return cls(
+            data=result,
+            requires_grad=tensor.requires_grad,
+            quant_state=tensor.quant_state,
+            blocksize=tensor.blocksize,
+            compress_statistics=tensor.compress_statistics,
+            quant_type=tensor.quant_type,
+            quant_storage=tensor.quant_storage,
+            module=tensor.module,
+            bnb_quantized=tensor.bnb_quantized,
+        )
+
+    return Parameter.__torch_function__(func, types, args, kwargs)
+
+
+# pylint: disable=protected-access
+def apply_bnb_torch_function_patch():
+    """
+    Patch Params4bit.__torch_function__ so that the tensor ops FSDP2 uses for
+    sharding (torch.chunk, torch.split) preserve Params4bit class identity and
+    quantization metadata.
+    """
+    from bitsandbytes.nn.modules import Params4bit
+
+    Params4bit.__torch_function__ = classmethod(patched_torch_function)
+
+    LOG.info("Successfully patched Params4bit.__torch_function__")
+
+
+# pylint: disable=protected-access
+def apply_init_sharded_param_patch():
+    """Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
+    from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+
+    # Get original source
+    original_source = inspect.getsource(FSDPParam._init_sharded_param)
+    original_source, _ = detab_code(original_source)
+
+    # Define the replacement
+    original_param_creation = """        self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
+        self.sharded_param.requires_grad_(param.requires_grad)"""
+
+    patched_param_creation = """        import bitsandbytes as bnb
+        if isinstance(param, bnb.nn.modules.Params4bit):
+            self.sharded_param = bnb.nn.modules.Params4bit(
+                data=sharded_param,
+                requires_grad=param.requires_grad,
+                quant_state=param.quant_state,
+                blocksize=param.blocksize,
+                compress_statistics=param.compress_statistics,
+                quant_type=param.quant_type,
+                quant_storage=param.quant_storage,
+                module=param.module,
+                bnb_quantized=param.bnb_quantized,
+            )
+            self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
+        else:
+            self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
+        self.sharded_param.requires_grad_(param.requires_grad)"""
+
+    # Apply the replacement
+    if original_param_creation in original_source:
+        patched_source = original_source.replace(
+            original_param_creation, patched_param_creation
+        )
+        patched_source = patched_source.replace(
+            "def _init_sharded_param(",
+            "def patched_init_sharded_param(",
+            1,
+        )
+
+        # Load necessary imports
+        module_name = FSDPParam.__module__
+        module = importlib.import_module(module_name)
+
+        items_to_import = []
+        for item in dir(module):
+            if item in patched_source:
+                items_to_import.append(item)
+
+        exec(  # pylint: disable=exec-used  # nosec B102
+            f"from {module_name} import ({', '.join(items_to_import)})",
+            globals(),
+        )
+        exec(patched_source, globals())  # pylint: disable=exec-used  # nosec B102
+
+        # Replace the method
+        FSDPParam._init_sharded_param = patched_init_sharded_param  # pylint: disable=undefined-variable  # noqa: F821
+        LOG.info("Successfully applied FSDP _init_sharded_param patch")
+    else:
+        LOG.warning("Could not find target code for _init_sharded_param patching")
+
+
+def apply_init_unsharded_param_patch():
+    """Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
+    from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+
+    # Get original source
+    original_source = inspect.getsource(FSDPParam.init_unsharded_param)
+    original_source, _ = detab_code(original_source)
+
+    # Define the replacement
+    original_param_creation = """            self._unsharded_param = nn.Parameter(
+                unsharded_param, requires_grad=self.sharded_param.requires_grad
+            )"""
+
+    patched_param_creation = """            import bitsandbytes as bnb
+            local_tensor = self.sharded_param._local_tensor
+            if isinstance(local_tensor, bnb.nn.modules.Params4bit):
+                self._unsharded_param = bnb.nn.modules.Params4bit(
+                    data=unsharded_param,
+                    requires_grad=self.sharded_param.requires_grad,
+                    quant_state=local_tensor.quant_state,
+                    blocksize=local_tensor.blocksize,
+                    compress_statistics=local_tensor.compress_statistics,
+                    quant_type=local_tensor.quant_type,
+                    quant_storage=local_tensor.quant_storage,
+                    module=local_tensor.module,
+                    bnb_quantized=local_tensor.bnb_quantized,
+                )
+            else:
+                self._unsharded_param = nn.Parameter(
+                    unsharded_param, requires_grad=self.sharded_param.requires_grad
+                )"""
+
+    # Apply the replacement
+    if original_param_creation in original_source:
+        patched_source = original_source.replace(
+            original_param_creation, patched_param_creation
+        )
+        patched_source = patched_source.replace(
+            "def init_unsharded_param(",
+            "def patched_init_unsharded_param(",
+            1,
+        )
+
+        # Load necessary imports
+        module_name = FSDPParam.__module__
+        module = importlib.import_module(module_name)
+
+        items_to_import = []
+        for item in dir(module):
+            if item in patched_source:
+                items_to_import.append(item)
+
+        exec(  # pylint: disable=exec-used  # nosec B102
+            f"from {module_name} import ({', '.join(items_to_import)})",
+            globals(),
+        )
+        exec(patched_source, globals())  # pylint: disable=exec-used  # nosec B102
+
+        # Replace the method
+        FSDPParam.init_unsharded_param = patched_init_unsharded_param  # pylint: disable=undefined-variable  # noqa: F821
+        LOG.info("Successfully applied FSDP init_unsharded_param patch")
+    else:
+        LOG.warning("Could not find target code for init_unsharded_param patching")
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 02e80dd8e..61eec65d5 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -559,20 +559,6 @@ class LoRAValidationMixin:
         )
         return data
 
-    @model_validator(mode="before")
-    @classmethod
-    def check_lora_8bit(cls, data):
-        if (
-            data.get("lora_mlp_kernel")
-            or data.get("lora_qkv_kernel")
-            or data.get("lora_o_kernel")
-        ):
-            if data.get("adapter") == "lora" and data.get("load_in_8bit"):
-                raise ValueError(
-                    "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
-                )
-        return data
-
     @model_validator(mode="before")
     @classmethod
     def check_lora_axolotl_unsloth(cls, data):
@@ -619,7 +605,7 @@ class LoRAValidationMixin:
 
     @model_validator(mode="before")
     @classmethod
-    def check_lora_kernel_8bit(cls, data):
+    def check_lora_kernels_8bit(cls, data):
         if (
             data.get("lora_mlp_kernel")
             or data.get("lora_qkv_kernel")
@@ -627,20 +613,36 @@ class LoRAValidationMixin:
         ):
             if data.get("adapter") == "lora" and data.get("load_in_8bit"):
                 raise ValueError(
-                    "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
+                    "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
+                    "compatible with 8-bit LoRA at the moment."
                 )
         return data
 
     @model_validator(mode="before")
     @classmethod
-    def check_lora_kernel_rl(cls, data):
+    def check_lora_kernels_dora(cls, data):
+        if (
+            data.get("lora_mlp_kernel")
+            or data.get("lora_qkv_kernel")
+            or data.get("lora_o_kernel")
+        ) and data.get("peft_use_dora"):
+            raise ValueError(
+                "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
+                "compatible with DoRA at the moment."
+            )
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_lora_kernels_rl(cls, data):
         if (
             data.get("lora_mlp_kernel")
             or data.get("lora_qkv_kernel")
             or data.get("lora_o_kernel")
         ) and data.get("rl"):
             raise ValueError(
-                "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with RL at the moment."
+                "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
+                "compatible with RL at the moment."
             )
         return data
diff --git a/tests/e2e/multigpu/test_fsdp2.py b/tests/e2e/multigpu/test_fsdp2.py
index 95ced1303..0bb255266 100644
--- a/tests/e2e/multigpu/test_fsdp2.py
+++ b/tests/e2e/multigpu/test_fsdp2.py
@@ -174,6 +174,69 @@ class TestFSDP2:
 
         verify_training_success(temp_dir)
 
+    @require_torch_2_7_0
+    def test_lora_sft_kernels(self, temp_dir):
+        cfg = DictDefault(
+            {
+                "base_model": "Qwen/Qwen2.5-0.5B",
+                "sequence_len": 2048,
+                "val_set_size": 0.01,
+                "datasets": [
+                    {
+                        "path": "tatsu-lab/alpaca",
+                        "type": "alpaca",
+                        "split": "train[:10%]",
+                    },
+                ],
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_target_linear": True,
+                "num_epochs": 1,
+                "max_steps": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch_fused",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "fsdp_version": 2,
+                "fsdp_config": {
+                    "offload_params": False,
+                    "cpu_ram_efficient_loading": False,
+                    "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
+                    "state_dict_type": "FULL_STATE_DICT",
+                    "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+                    "reshard_after_forward": True,
+                },
+                "use_tensorboard": True,
+                "bf16": True,
+                "lora_mlp_kernel": True,
+                "lora_qkv_kernel": True,
+                "lora_o_kernel": True,
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
+                "--num-processes",
+                "2",
+                "--main-process-port",
+                f"{get_torch_dist_unique_port()}",
+            ]
+        )
+
+        verify_training_success(temp_dir)
+
     @require_torch_2_7_0
     def test_qlora_sft(self, temp_dir):
         cfg = DictDefault(
@@ -236,6 +299,70 @@ class TestFSDP2:
 
         verify_training_success(temp_dir)
 
+    @require_torch_2_7_0
+    def test_qlora_sft_kernels(self, temp_dir):
+        cfg = DictDefault(
+            {
+                "base_model": "Qwen/Qwen2.5-0.5B",
+                "sequence_len": 2048,
+                "val_set_size": 0.01,
+                "datasets": [
+                    {
+                        "path": "tatsu-lab/alpaca",
+                        
"type": "alpaca", + "split": "train[:10%]", + }, + ], + "load_in_4bit": True, + "adapter": "qlora", + "lora_r": 8, + "lora_alpha": 16, + "lora_target_linear": True, + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "flash_attention": True, + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": False, + "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + "use_tensorboard": True, + "bf16": True, + "lora_mlp_kernel": True, + "lora_qkv_kernel": True, + "lora_o_kernel": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_training_success(temp_dir) + @require_torch_2_7_0 def test_dpo_fft(self, temp_dir): cfg = DictDefault( diff --git a/tests/e2e/patched/test_fsdp2_qlora.py b/tests/e2e/patched/test_fsdp2_qlora.py new file mode 100644 index 000000000..9dd053ad8 --- /dev/null +++ b/tests/e2e/patched/test_fsdp2_qlora.py @@ -0,0 +1,131 @@ +"""Integration tests for FSDP Params4bit patches.""" + +from unittest.mock import Mock, patch + +import bitsandbytes as bnb +import pytest +import torch +from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam + +from axolotl.monkeypatch.fsdp2_qlora import ( + apply_bnb_torch_function_patch, + patched_torch_function, +) + + +@pytest.fixture +def mock_params4bit(): + """Create a mock Params4bit instance with test attributes.""" + mock_instance = Mock() + mock_instance.requires_grad = True + mock_instance.quant_state = "test_state" + mock_instance.blocksize = 128 + mock_instance.compress_statistics = True + mock_instance.quant_type = "fp4" + mock_instance.quant_storage = "test_storage" + mock_instance.module = "test_module" + mock_instance.bnb_quantized = True + return mock_instance + + +class TestBnbTorchFunctionPatch: + """Test the Params4bit.__torch_function__ patch.""" + + def test_apply_patch(self): + """Test that the patch can be applied.""" + with patch("bitsandbytes.nn.modules.Params4bit") as mock_cls: + apply_bnb_torch_function_patch() + assert hasattr(mock_cls, "__torch_function__") + assert isinstance(mock_cls.__torch_function__, classmethod) + + # pylint: disable=redefined-outer-name + def test_torch_chunk_preserves_attributes(self, mock_params4bit): + """Test that torch.chunk preserves Params4bit attributes.""" + mock_cls = Mock() + chunks = (torch.tensor([1, 2]), torch.tensor([3, 4])) + + with patch("torch.nn.Parameter.__torch_function__", return_value=chunks): + result = patched_torch_function( + mock_cls, + torch.chunk, + (type(mock_params4bit),), + args=(mock_params4bit, 2), + ) + + assert isinstance(result, tuple) + assert len(result) == 2 + + # Check that Params4bit constructor was called with preserved attributes + assert mock_cls.call_count == 2 + for call in mock_cls.call_args_list: + kwargs = call[1] + assert kwargs["requires_grad"] == mock_params4bit.requires_grad + assert kwargs["quant_state"] == 
mock_params4bit.quant_state + assert kwargs["blocksize"] == mock_params4bit.blocksize + + # pylint: disable=redefined-outer-name + def test_other_functions_fallback(self, mock_params4bit): + """Test that non-chunk/split functions use Parameter fallback.""" + mock_cls = Mock() + fallback_result = torch.tensor([5, 6, 7]) + + with patch( + "torch.nn.Parameter.__torch_function__", return_value=fallback_result + ) as mock_fallback: + result = patched_torch_function( + mock_cls, torch.add, (type(mock_params4bit),), args=(mock_params4bit, 1) + ) + + # Should call Parameter.__torch_function__ and return its result + mock_fallback.assert_called_once() + assert result is fallback_result + mock_cls.assert_not_called() + + +class TestFSDPPatchIntegration: + """Test FSDP patch integration.""" + + @pytest.mark.integration + def test_all_patches_together(self): + """Test that all patches can be applied together.""" + from axolotl.monkeypatch.fsdp2_qlora import ( + apply_init_sharded_param_patch, + apply_init_unsharded_param_patch, + ) + + # Store original methods before patching + original_torch_function = getattr( + bnb.nn.modules.Params4bit, "__torch_function__", None + ) + + # pylint: disable=protected-access + original_init_sharded = FSDPParam._init_sharded_param + original_init_unsharded = FSDPParam.init_unsharded_param + + # Apply patches + apply_bnb_torch_function_patch() + apply_init_sharded_param_patch() + apply_init_unsharded_param_patch() + + # Verify patches were applied + current_torch_function = getattr( + bnb.nn.modules.Params4bit, "__torch_function__", None + ) + if original_torch_function is not None: + assert ( + current_torch_function != original_torch_function + ), "Params4bit.__torch_function__ was not patched" + else: + assert ( + current_torch_function is not None + ), "Params4bit.__torch_function__ was not added" + + # Check that FSDP methods were patched + assert ( + # pylint: disable=protected-access + FSDPParam._init_sharded_param + != original_init_sharded + ), "_init_sharded_param was not patched" + assert ( + FSDPParam.init_unsharded_param != original_init_unsharded + ), "init_unsharded_param was not patched"
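
For reference, the user-facing switch for this new path is just the combination of options exercised by the tests above: FSDP2 plus a (Q)LoRA adapter plus the lora_*_kernel flags. A minimal config sketch, reusing the illustrative model, dataset, and FSDP settings from the test configs (not tuned recommendations):

base_model: Qwen/Qwen2.5-0.5B
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
adapter: qlora
load_in_4bit: true
lora_r: 8
lora_alpha: 16
lora_target_linear: true
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
bf16: true
flash_attention: true
fsdp_version: 2
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: false
  transformer_layer_cls_to_wrap: Qwen2DecoderLayer
  state_dict_type: FULL_STATE_DICT
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  reshard_after_forward: true

With adapter set to qlora and fsdp_version 2, PatchManager applies the Params4bit / FSDPParam patches above before model load; the kernel flags then route the LoRA projections through the fused path in src/axolotl/kernels/lora.py.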