From 09145de8fa0306c3b88212da71564d3e3892ad31 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 13 Aug 2025 19:41:07 -0400
Subject: [PATCH] upgrade transformers==4.55.1 and bitsandbytes==0.47.0 (#3064)

* upgrade transformers==4.55.1
* also upgrade bnb
* remove bnb params4bit patch (upstreamed)
* use latest causal-conv1d
* fix patching ring-flash-attn with now missing imports

---------

Co-authored-by: Dan Saunders
---
 docker/Dockerfile-base                       |   2 +-
 requirements.txt                             |   4 +-
 src/axolotl/loaders/patch_manager.py         |   2 -
 src/axolotl/monkeypatch/fsdp2_qlora.py       |  61 -----------
 .../monkeypatch/ring_attn/adapters/batch.py  |  11 +-
 src/axolotl/monkeypatch/ring_attn/patch.py   |  11 +-
 src/axolotl/utils/schemas/validation.py      |  21 +++-
 tests/e2e/patched/test_fsdp2_qlora.py        | 102 +-----------------
 8 files changed, 38 insertions(+), 176 deletions(-)

diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base
index 0434a583f..d1151cedd 100644
--- a/docker/Dockerfile-base
+++ b/docker/Dockerfile-base
@@ -37,7 +37,7 @@ WORKDIR /workspace
 RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
     python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
-    python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
+    CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
     python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
     python3 -m pip cache purge

diff --git a/requirements.txt b/requirements.txt
index 370bf5a5e..5f7767812 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 # START section of dependencies that don't install on Darwin/MacOS
-bitsandbytes==0.46.1
+bitsandbytes==0.47.0
 # triton 3.4.0 is not compatible with CCE
 triton>=3.0.0,<3.4.0
 mamba-ssm==1.2.0.post1
@@ -14,7 +14,7 @@ packaging==23.2
 huggingface_hub>=0.33.0
 peft==0.17.0
-transformers==4.55.0
+transformers==4.55.1
 tokenizers>=0.21.1
 accelerate==1.10.0
 datasets==4.0.0

diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index f1ca3c725..628d897d0 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -285,12 +285,10 @@ class PatchManager:
             and self.cfg.adapter == "qlora"
         ):
             from axolotl.monkeypatch.fsdp2_qlora import (
-                apply_bnb_torch_function_patch,
                 apply_init_sharded_param_patch,
                 apply_init_unsharded_param_patch,
             )
 
-            apply_bnb_torch_function_patch()
             apply_init_sharded_param_patch()
             apply_init_unsharded_param_patch()

diff --git a/src/axolotl/monkeypatch/fsdp2_qlora.py b/src/axolotl/monkeypatch/fsdp2_qlora.py
index a2cb7e472..5a4332fff 100644
--- a/src/axolotl/monkeypatch/fsdp2_qlora.py
+++ b/src/axolotl/monkeypatch/fsdp2_qlora.py
@@ -9,73 +9,12 @@ Params4bit parameters.
 import importlib
 import inspect
 
-import torch
-from torch.nn import Parameter
-
 from axolotl.monkeypatch.utils import detab_code
 from axolotl.utils.logging import get_logger
 
 LOG = get_logger(__name__)
 
 
-def patched_torch_function(cls, func, types, args=(), kwargs=None):
-    """
-    Patched version of Params4bit.__torch_function__ for preserving Params4bit
-    class identity and attributes.
- """ - if kwargs is None: - kwargs = {} - - if func in [torch.chunk, torch.split]: - tensor = args[0] - result = Parameter.__torch_function__(func, types, args, kwargs) - - if isinstance(result, tuple): - return tuple( - cls( - data=chunk, - requires_grad=tensor.requires_grad, - quant_state=tensor.quant_state, - blocksize=tensor.blocksize, - compress_statistics=tensor.compress_statistics, - quant_type=tensor.quant_type, - quant_storage=tensor.quant_storage, - module=tensor.module, - bnb_quantized=tensor.bnb_quantized, - ) - for chunk in result - ) - - return cls( - data=result, - requires_grad=tensor.requires_grad, - quant_state=tensor.quant_state, - blocksize=tensor.blocksize, - compress_statistics=tensor.compress_statistics, - quant_type=tensor.quant_type, - quant_storage=tensor.quant_storage, - module=tensor.module, - bnb_quantized=tensor.bnb_quantized, - ) - - return Parameter.__torch_function__(func, types, args, kwargs) - - -# pylint: disable=protected-access -def apply_bnb_torch_function_patch(): - """ - Patch Params4bit.__torch_function__ using Axolotl-style approach. - - Returns: - True if patching succeeded, False otherwise. - """ - from bitsandbytes.nn.modules import Params4bit - - Params4bit.__torch_function__ = classmethod(patched_torch_function) - - LOG.info("Successfully patched Params4bit.__torch_function__") - - # pylint: disable=protected-access def apply_init_sharded_param_patch(): """Apply patch to FSDPParam._init_sharded_param to support Params4bit.""" diff --git a/src/axolotl/monkeypatch/ring_attn/adapters/batch.py b/src/axolotl/monkeypatch/ring_attn/adapters/batch.py index ebed9ebdc..607b4dd71 100644 --- a/src/axolotl/monkeypatch/ring_attn/adapters/batch.py +++ b/src/axolotl/monkeypatch/ring_attn/adapters/batch.py @@ -20,12 +20,15 @@ from ring_flash_attn import ring_flash_attn_func from ring_flash_attn.adapters.hf_adapter import check_params from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal -try: +try: # pylint: disable=duplicate-code from transformers.modeling_flash_attention_utils import _flash_supports_window except ImportError: - from transformers.modeling_flash_attention_utils import ( - _flash_supports_window_size as _flash_supports_window, - ) + try: + from transformers.modeling_flash_attention_utils import ( + _flash_supports_window_size as _flash_supports_window, + ) + except ImportError: + _flash_supports_window = True from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py index 934687a16..ea0f9dd02 100644 --- a/src/axolotl/monkeypatch/ring_attn/patch.py +++ b/src/axolotl/monkeypatch/ring_attn/patch.py @@ -15,12 +15,15 @@ import torch import torch.distributed as dist from torch.distributed import DeviceMesh -try: +try: # pylint: disable=duplicate-code from transformers.modeling_flash_attention_utils import _flash_supports_window except ImportError: - from transformers.modeling_flash_attention_utils import ( - _flash_supports_window_size as _flash_supports_window, - ) + try: + from transformers.modeling_flash_attention_utils import ( + _flash_supports_window_size as _flash_supports_window, + ) + except ImportError: + _flash_supports_window = True from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.utils.logging import get_logger diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 72991c947..0d6d05a0e 100644 --- 
+++ b/src/axolotl/utils/schemas/validation.py
@@ -3,6 +3,7 @@
 # pylint: disable=too-many-boolean-expressions
 
 import json
+import sys
 import tempfile
 from pathlib import Path
@@ -1251,10 +1252,26 @@ class ComplexValidationMixin:
         try:
             import transformers.modeling_flash_attention_utils
+            from transformers.utils import is_flash_attn_greater_or_equal
 
             # pylint: disable=protected-access
-            transformers.modeling_flash_attention_utils._flash_supports_window_size = (
-                transformers.modeling_flash_attention_utils._flash_supports_window
+            transformers.modeling_flash_attention_utils._flash_supports_window = (
+                True
+            )
+            setattr(
+                sys.modules["transformers.modeling_flash_attention_utils"],
+                "_flash_supports_window",
+                True,
+            )
+            setattr(
+                sys.modules["transformers.modeling_flash_attention_utils"],
+                "_flash_supports_window_size",
+                True,
+            )
+            setattr(
+                sys.modules["transformers.modeling_flash_attention_utils"],
+                "is_flash_attn_greater_or_equal",
+                is_flash_attn_greater_or_equal,
             )
             import ring_flash_attn  # noqa: F401 # pylint:disable=unused-import
         except ImportError as exception:

diff --git a/tests/e2e/patched/test_fsdp2_qlora.py b/tests/e2e/patched/test_fsdp2_qlora.py
index 9dd053ad8..ca17b81d1 100644
--- a/tests/e2e/patched/test_fsdp2_qlora.py
+++ b/tests/e2e/patched/test_fsdp2_qlora.py
@@ -1,126 +1,28 @@
-"""Integration tests for FSDP Params4bit patches."""
+"""Integration tests for FSDP2 Params4bit patches."""
 
-from unittest.mock import Mock, patch
-
-import bitsandbytes as bnb
 import pytest
-import torch
 from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
 
-from axolotl.monkeypatch.fsdp2_qlora import (
-    apply_bnb_torch_function_patch,
-    patched_torch_function,
-)
-
-
-@pytest.fixture
-def mock_params4bit():
-    """Create a mock Params4bit instance with test attributes."""
-    mock_instance = Mock()
-    mock_instance.requires_grad = True
-    mock_instance.quant_state = "test_state"
-    mock_instance.blocksize = 128
-    mock_instance.compress_statistics = True
-    mock_instance.quant_type = "fp4"
-    mock_instance.quant_storage = "test_storage"
-    mock_instance.module = "test_module"
-    mock_instance.bnb_quantized = True
-    return mock_instance
-
-
-class TestBnbTorchFunctionPatch:
-    """Test the Params4bit.__torch_function__ patch."""
-
-    def test_apply_patch(self):
-        """Test that the patch can be applied."""
-        with patch("bitsandbytes.nn.modules.Params4bit") as mock_cls:
-            apply_bnb_torch_function_patch()
-            assert hasattr(mock_cls, "__torch_function__")
-            assert isinstance(mock_cls.__torch_function__, classmethod)
-
-    # pylint: disable=redefined-outer-name
-    def test_torch_chunk_preserves_attributes(self, mock_params4bit):
-        """Test that torch.chunk preserves Params4bit attributes."""
-        mock_cls = Mock()
-        chunks = (torch.tensor([1, 2]), torch.tensor([3, 4]))
-
-        with patch("torch.nn.Parameter.__torch_function__", return_value=chunks):
-            result = patched_torch_function(
-                mock_cls,
-                torch.chunk,
-                (type(mock_params4bit),),
-                args=(mock_params4bit, 2),
-            )
-
-        assert isinstance(result, tuple)
-        assert len(result) == 2
-
-        # Check that Params4bit constructor was called with preserved attributes
-        assert mock_cls.call_count == 2
-        for call in mock_cls.call_args_list:
-            kwargs = call[1]
-            assert kwargs["requires_grad"] == mock_params4bit.requires_grad
-            assert kwargs["quant_state"] == mock_params4bit.quant_state
-            assert kwargs["blocksize"] == mock_params4bit.blocksize
-
-    # pylint: disable=redefined-outer-name
-    def test_other_functions_fallback(self, mock_params4bit):
-        """Test that non-chunk/split functions use Parameter fallback."""
-        mock_cls = Mock()
-        fallback_result = torch.tensor([5, 6, 7])
-
-        with patch(
-            "torch.nn.Parameter.__torch_function__", return_value=fallback_result
-        ) as mock_fallback:
-            result = patched_torch_function(
-                mock_cls, torch.add, (type(mock_params4bit),), args=(mock_params4bit, 1)
-            )
-
-        # Should call Parameter.__torch_function__ and return its result
-        mock_fallback.assert_called_once()
-        assert result is fallback_result
-        mock_cls.assert_not_called()
 
 
 class TestFSDPPatchIntegration:
     """Test FSDP patch integration."""
 
     @pytest.mark.integration
-    def test_all_patches_together(self):
+    def test_fsdp2_init_patches(self):
         """Test that all patches can be applied together."""
         from axolotl.monkeypatch.fsdp2_qlora import (
             apply_init_sharded_param_patch,
             apply_init_unsharded_param_patch,
         )
 
-        # Store original methods before patching
-        original_torch_function = getattr(
-            bnb.nn.modules.Params4bit, "__torch_function__", None
-        )
-
         # pylint: disable=protected-access
         original_init_sharded = FSDPParam._init_sharded_param
         original_init_unsharded = FSDPParam.init_unsharded_param
 
         # Apply patches
-        apply_bnb_torch_function_patch()
         apply_init_sharded_param_patch()
         apply_init_unsharded_param_patch()
 
-        # Verify patches were applied
-        current_torch_function = getattr(
-            bnb.nn.modules.Params4bit, "__torch_function__", None
-        )
-        if original_torch_function is not None:
-            assert (
-                current_torch_function != original_torch_function
-            ), "Params4bit.__torch_function__ was not patched"
-        else:
-            assert (
-                current_torch_function is not None
-            ), "Params4bit.__torch_function__ was not added"
-
-        # Check that FSDP methods were patched
         assert (  # pylint: disable=protected-access
             FSDPParam._init_sharded_param