upgrade transformers==4.55.1 and bitsandbytes==0.47.0 (#3064)

* upgrade transformers==4.55.1

* also upgrade bnb

* remove bnb params4bit patch (upstreamed)

* use latest causal-conv1d

* fix ring-flash-attn patching to handle now-missing transformers imports

---------

Co-authored-by: Dan Saunders <danjsaund@gmail.com>
Author: Wing Lian
Date: 2025-08-13 19:41:07 -04:00
Committed by: GitHub
Parent: e0a2523a3b
Commit: 09145de8fa
8 changed files with 38 additions and 176 deletions

View File

@@ -37,7 +37,7 @@ WORKDIR /workspace
 RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
     python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
-    python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
+    CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
     python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
     python3 -m pip cache purge
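
A quick post-build smoke test, not part of this commit, can confirm that the pinned causal_conv1d and the git-installed mamba_ssm import cleanly inside the image (the __version__ attributes are assumed to exist, hence the getattr fallback):

# Hypothetical sanity check; run inside the built image's Python environment.
import torch
import causal_conv1d
import mamba_ssm

print("torch:", torch.__version__, "| cuda:", torch.version.cuda)
print("causal_conv1d:", getattr(causal_conv1d, "__version__", "unknown"))
print("mamba_ssm:", getattr(mamba_ssm, "__version__", "unknown"))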

View File

@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 # START section of dependencies that don't install on Darwin/MacOS
-bitsandbytes==0.46.1
+bitsandbytes==0.47.0
 # triton 3.4.0 is not compatible with CCE
 triton>=3.0.0,<3.4.0
 mamba-ssm==1.2.0.post1
@@ -14,7 +14,7 @@ packaging==23.2
 huggingface_hub>=0.33.0
 peft==0.17.0
-transformers==4.55.0
+transformers==4.55.1
 tokenizers>=0.21.1
 accelerate==1.10.0
 datasets==4.0.0
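
For completeness, a minimal sketch (not part of this diff) of asserting at runtime that an environment matches the updated pins; the set of packages checked here is illustrative:

# Hypothetical pin check using importlib.metadata; extend the dict as needed.
from importlib.metadata import version

EXPECTED = {
    "transformers": "4.55.1",
    "bitsandbytes": "0.47.0",
    "peft": "0.17.0",
    "accelerate": "1.10.0",
}

for name, pinned in EXPECTED.items():
    installed = version(name)
    assert installed == pinned, f"{name}: expected {pinned}, found {installed}"
print("all pins satisfied")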

View File

@@ -285,12 +285,10 @@ class PatchManager:
             and self.cfg.adapter == "qlora"
         ):
             from axolotl.monkeypatch.fsdp2_qlora import (
-                apply_bnb_torch_function_patch,
                 apply_init_sharded_param_patch,
                 apply_init_unsharded_param_patch,
             )
 
-            apply_bnb_torch_function_patch()
             apply_init_sharded_param_patch()
             apply_init_unsharded_param_patch()

View File

@@ -9,73 +9,12 @@ Params4bit parameters.
 import importlib
 import inspect
 
 import torch
-from torch.nn import Parameter
 
 from axolotl.monkeypatch.utils import detab_code
 from axolotl.utils.logging import get_logger
 
 LOG = get_logger(__name__)
 
 
-def patched_torch_function(cls, func, types, args=(), kwargs=None):
-    """
-    Patched version of Params4bit.__torch_function__ for preserving Params4bit
-    class identity and attributes.
-    """
-    if kwargs is None:
-        kwargs = {}
-
-    if func in [torch.chunk, torch.split]:
-        tensor = args[0]
-        result = Parameter.__torch_function__(func, types, args, kwargs)
-
-        if isinstance(result, tuple):
-            return tuple(
-                cls(
-                    data=chunk,
-                    requires_grad=tensor.requires_grad,
-                    quant_state=tensor.quant_state,
-                    blocksize=tensor.blocksize,
-                    compress_statistics=tensor.compress_statistics,
-                    quant_type=tensor.quant_type,
-                    quant_storage=tensor.quant_storage,
-                    module=tensor.module,
-                    bnb_quantized=tensor.bnb_quantized,
-                )
-                for chunk in result
-            )
-
-        return cls(
-            data=result,
-            requires_grad=tensor.requires_grad,
-            quant_state=tensor.quant_state,
-            blocksize=tensor.blocksize,
-            compress_statistics=tensor.compress_statistics,
-            quant_type=tensor.quant_type,
-            quant_storage=tensor.quant_storage,
-            module=tensor.module,
-            bnb_quantized=tensor.bnb_quantized,
-        )
-
-    return Parameter.__torch_function__(func, types, args, kwargs)
-
-
-# pylint: disable=protected-access
-def apply_bnb_torch_function_patch():
-    """
-    Patch Params4bit.__torch_function__ using Axolotl-style approach.
-
-    Returns:
-        True if patching succeeded, False otherwise.
-    """
-    from bitsandbytes.nn.modules import Params4bit
-
-    Params4bit.__torch_function__ = classmethod(patched_torch_function)
-    LOG.info("Successfully patched Params4bit.__torch_function__")
-
-
 # pylint: disable=protected-access
 def apply_init_sharded_param_patch():
     """Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
View File

@@ -20,12 +20,15 @@ from ring_flash_attn import ring_flash_attn_func
 from ring_flash_attn.adapters.hf_adapter import check_params
 from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal
 
-try:
+try:  # pylint: disable=duplicate-code
     from transformers.modeling_flash_attention_utils import _flash_supports_window
 except ImportError:
-    from transformers.modeling_flash_attention_utils import (
-        _flash_supports_window_size as _flash_supports_window,
-    )
+    try:
+        from transformers.modeling_flash_attention_utils import (
+            _flash_supports_window_size as _flash_supports_window,
+        )
+    except ImportError:
+        _flash_supports_window = True
 
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
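
For reference, a small probe (a sketch; the mapping of attribute names to transformers releases is inferred from this diff, not checked against every version) shows which branch of the nested fallback a given installation will take:

# Report which flash-attention window-support flag this transformers build has.
import transformers
import transformers.modeling_flash_attention_utils as fa_utils

if hasattr(fa_utils, "_flash_supports_window"):
    branch = "_flash_supports_window (newer releases)"
elif hasattr(fa_utils, "_flash_supports_window_size"):
    branch = "_flash_supports_window_size (older releases)"
else:
    branch = "neither name present (e.g. 4.55.1); the patch assumes support"
print(transformers.__version__, "->", branch)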

View File

@@ -15,12 +15,15 @@ import torch
 import torch.distributed as dist
 from torch.distributed import DeviceMesh
 
-try:
+try:  # pylint: disable=duplicate-code
     from transformers.modeling_flash_attention_utils import _flash_supports_window
 except ImportError:
-    from transformers.modeling_flash_attention_utils import (
-        _flash_supports_window_size as _flash_supports_window,
-    )
+    try:
+        from transformers.modeling_flash_attention_utils import (
+            _flash_supports_window_size as _flash_supports_window,
+        )
+    except ImportError:
+        _flash_supports_window = True
 
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 from axolotl.utils.logging import get_logger

View File

@@ -3,6 +3,7 @@
 # pylint: disable=too-many-boolean-expressions
 import json
+import sys
 import tempfile
 from pathlib import Path
@@ -1251,10 +1252,26 @@ class ComplexValidationMixin:
         try:
             import transformers.modeling_flash_attention_utils
+            from transformers.utils import is_flash_attn_greater_or_equal
 
             # pylint: disable=protected-access
-            transformers.modeling_flash_attention_utils._flash_supports_window_size = (
-                transformers.modeling_flash_attention_utils._flash_supports_window
+            transformers.modeling_flash_attention_utils._flash_supports_window = (
+                True
             )
+            setattr(
+                sys.modules["transformers.modeling_flash_attention_utils"],
+                "_flash_supports_window",
+                True,
+            )
+            setattr(
+                sys.modules["transformers.modeling_flash_attention_utils"],
+                "_flash_supports_window_size",
+                True,
+            )
+            setattr(
+                sys.modules["transformers.modeling_flash_attention_utils"],
+                "is_flash_attn_greater_or_equal",
+                is_flash_attn_greater_or_equal,
+            )
 
             import ring_flash_attn  # noqa: F401 # pylint:disable=unused-import
         except ImportError as exception:

View File

@@ -1,126 +1,28 @@
"""Integration tests for FSDP Params4bit patches."""
"""Integration tests for FSDP2 Params4bit patches."""
from unittest.mock import Mock, patch
import bitsandbytes as bnb
import pytest
import torch
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
from axolotl.monkeypatch.fsdp2_qlora import (
apply_bnb_torch_function_patch,
patched_torch_function,
)
@pytest.fixture
def mock_params4bit():
"""Create a mock Params4bit instance with test attributes."""
mock_instance = Mock()
mock_instance.requires_grad = True
mock_instance.quant_state = "test_state"
mock_instance.blocksize = 128
mock_instance.compress_statistics = True
mock_instance.quant_type = "fp4"
mock_instance.quant_storage = "test_storage"
mock_instance.module = "test_module"
mock_instance.bnb_quantized = True
return mock_instance
class TestBnbTorchFunctionPatch:
"""Test the Params4bit.__torch_function__ patch."""
def test_apply_patch(self):
"""Test that the patch can be applied."""
with patch("bitsandbytes.nn.modules.Params4bit") as mock_cls:
apply_bnb_torch_function_patch()
assert hasattr(mock_cls, "__torch_function__")
assert isinstance(mock_cls.__torch_function__, classmethod)
# pylint: disable=redefined-outer-name
def test_torch_chunk_preserves_attributes(self, mock_params4bit):
"""Test that torch.chunk preserves Params4bit attributes."""
mock_cls = Mock()
chunks = (torch.tensor([1, 2]), torch.tensor([3, 4]))
with patch("torch.nn.Parameter.__torch_function__", return_value=chunks):
result = patched_torch_function(
mock_cls,
torch.chunk,
(type(mock_params4bit),),
args=(mock_params4bit, 2),
)
assert isinstance(result, tuple)
assert len(result) == 2
# Check that Params4bit constructor was called with preserved attributes
assert mock_cls.call_count == 2
for call in mock_cls.call_args_list:
kwargs = call[1]
assert kwargs["requires_grad"] == mock_params4bit.requires_grad
assert kwargs["quant_state"] == mock_params4bit.quant_state
assert kwargs["blocksize"] == mock_params4bit.blocksize
# pylint: disable=redefined-outer-name
def test_other_functions_fallback(self, mock_params4bit):
"""Test that non-chunk/split functions use Parameter fallback."""
mock_cls = Mock()
fallback_result = torch.tensor([5, 6, 7])
with patch(
"torch.nn.Parameter.__torch_function__", return_value=fallback_result
) as mock_fallback:
result = patched_torch_function(
mock_cls, torch.add, (type(mock_params4bit),), args=(mock_params4bit, 1)
)
# Should call Parameter.__torch_function__ and return its result
mock_fallback.assert_called_once()
assert result is fallback_result
mock_cls.assert_not_called()
class TestFSDPPatchIntegration:
"""Test FSDP patch integration."""
@pytest.mark.integration
def test_all_patches_together(self):
def test_fsdp2_init_patches(self):
"""Test that all patches can be applied together."""
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
)
# Store original methods before patching
original_torch_function = getattr(
bnb.nn.modules.Params4bit, "__torch_function__", None
)
# pylint: disable=protected-access
original_init_sharded = FSDPParam._init_sharded_param
original_init_unsharded = FSDPParam.init_unsharded_param
# Apply patches
apply_bnb_torch_function_patch()
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()
# Verify patches were applied
current_torch_function = getattr(
bnb.nn.modules.Params4bit, "__torch_function__", None
)
if original_torch_function is not None:
assert (
current_torch_function != original_torch_function
), "Params4bit.__torch_function__ was not patched"
else:
assert (
current_torch_function is not None
), "Params4bit.__torch_function__ was not added"
# Check that FSDP methods were patched
assert (
# pylint: disable=protected-access
FSDPParam._init_sharded_param